您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

knowledgebase_service.py 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. from datetime import datetime
  17. from peewee import fn
  18. from api.db import StatusEnum, TenantPermission
  19. from api.db.db_models import DB, Document, Knowledgebase, Tenant, User, UserTenant
  20. from api.db.services.common_service import CommonService
  21. from api.utils import current_timestamp, datetime_format
  22. class KnowledgebaseService(CommonService):
  23. """Service class for managing knowledge base operations.
  24. This class extends CommonService to provide specialized functionality for knowledge base
  25. management, including document parsing status tracking, access control, and configuration
  26. management. It handles operations such as listing, creating, updating, and deleting
  27. knowledge bases, as well as managing their associated documents and permissions.
  28. The class implements a comprehensive set of methods for:
  29. - Document parsing status verification
  30. - Knowledge base access control
  31. - Parser configuration management
  32. - Tenant-based knowledge base organization
  33. Attributes:
  34. model: The Knowledgebase model class for database operations.
  35. """
  36. model = Knowledgebase
  37. @classmethod
  38. @DB.connection_context()
  39. def accessible4deletion(cls, kb_id, user_id):
  40. """Check if a knowledge base can be deleted by a specific user.
  41. This method verifies whether a user has permission to delete a knowledge base
  42. by checking if they are the creator of that knowledge base.
  43. Args:
  44. kb_id (str): The unique identifier of the knowledge base to check.
  45. user_id (str): The unique identifier of the user attempting the deletion.
  46. Returns:
  47. bool: True if the user has permission to delete the knowledge base,
  48. False if the user doesn't have permission or the knowledge base doesn't exist.
  49. Example:
  50. >>> KnowledgebaseService.accessible4deletion("kb123", "user456")
  51. True
  52. Note:
  53. - This method only checks creator permissions
  54. - A return value of False can mean either:
  55. 1. The knowledge base doesn't exist
  56. 2. The user is not the creator of the knowledge base
  57. """
  58. # Check if a knowledge base can be deleted by a user
  59. docs = cls.model.select(
  60. cls.model.id).where(cls.model.id == kb_id, cls.model.created_by == user_id).paginate(0, 1)
  61. docs = docs.dicts()
  62. if not docs:
  63. return False
  64. return True
  65. @classmethod
  66. @DB.connection_context()
  67. def is_parsed_done(cls, kb_id):
  68. # Check if all documents in the knowledge base have completed parsing
  69. #
  70. # Args:
  71. # kb_id: Knowledge base ID
  72. #
  73. # Returns:
  74. # If all documents are parsed successfully, returns (True, None)
  75. # If any document is not fully parsed, returns (False, error_message)
  76. from api.db import TaskStatus
  77. from api.db.services.document_service import DocumentService
  78. # Get knowledge base information
  79. kbs = cls.query(id=kb_id)
  80. if not kbs:
  81. return False, "Knowledge base not found"
  82. kb = kbs[0]
  83. # Get all documents in the knowledge base
  84. docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "", [], [])
  85. # Check parsing status of each document
  86. for doc in docs:
  87. # If document is being parsed, don't allow chat creation
  88. if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
  89. return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
  90. # If document is not yet parsed and has no chunks, don't allow chat creation
  91. if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
  92. return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
  93. return True, None
  94. @classmethod
  95. @DB.connection_context()
  96. def list_documents_by_ids(cls, kb_ids):
  97. # Get document IDs associated with given knowledge base IDs
  98. # Args:
  99. # kb_ids: List of knowledge base IDs
  100. # Returns:
  101. # List of document IDs
  102. doc_ids = cls.model.select(Document.id.alias("document_id")).join(Document, on=(cls.model.id == Document.kb_id)).where(
  103. cls.model.id.in_(kb_ids)
  104. )
  105. doc_ids = list(doc_ids.dicts())
  106. doc_ids = [doc["document_id"] for doc in doc_ids]
  107. return doc_ids
  108. @classmethod
  109. @DB.connection_context()
  110. def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
  111. page_number, items_per_page,
  112. orderby, desc, keywords,
  113. parser_id=None
  114. ):
  115. # Get knowledge bases by tenant IDs with pagination and filtering
  116. # Args:
  117. # joined_tenant_ids: List of tenant IDs
  118. # user_id: Current user ID
  119. # page_number: Page number for pagination
  120. # items_per_page: Number of items per page
  121. # orderby: Field to order by
  122. # desc: Boolean indicating descending order
  123. # keywords: Search keywords
  124. # parser_id: Optional parser ID filter
  125. # Returns:
  126. # Tuple of (knowledge_base_list, total_count)
  127. fields = [
  128. cls.model.id,
  129. cls.model.avatar,
  130. cls.model.name,
  131. cls.model.language,
  132. cls.model.description,
  133. cls.model.tenant_id,
  134. cls.model.permission,
  135. cls.model.doc_num,
  136. cls.model.token_num,
  137. cls.model.chunk_num,
  138. cls.model.parser_id,
  139. cls.model.embd_id,
  140. User.nickname,
  141. User.avatar.alias('tenant_avatar'),
  142. cls.model.update_time
  143. ]
  144. if keywords:
  145. kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
  146. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  147. TenantPermission.TEAM.value)) | (
  148. cls.model.tenant_id == user_id))
  149. & (cls.model.status == StatusEnum.VALID.value),
  150. (fn.LOWER(cls.model.name).contains(keywords.lower()))
  151. )
  152. else:
  153. kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
  154. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  155. TenantPermission.TEAM.value)) | (
  156. cls.model.tenant_id == user_id))
  157. & (cls.model.status == StatusEnum.VALID.value)
  158. )
  159. if parser_id:
  160. kbs = kbs.where(cls.model.parser_id == parser_id)
  161. if desc:
  162. kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
  163. else:
  164. kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
  165. count = kbs.count()
  166. if page_number and items_per_page:
  167. kbs = kbs.paginate(page_number, items_per_page)
  168. return list(kbs.dicts()), count
  169. @classmethod
  170. @DB.connection_context()
  171. def get_kb_ids(cls, tenant_id):
  172. # Get all knowledge base IDs for a tenant
  173. # Args:
  174. # tenant_id: Tenant ID
  175. # Returns:
  176. # List of knowledge base IDs
  177. fields = [
  178. cls.model.id,
  179. ]
  180. kbs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id)
  181. kb_ids = [kb.id for kb in kbs]
  182. return kb_ids
  183. @classmethod
  184. @DB.connection_context()
  185. def get_detail(cls, kb_id):
  186. # Get detailed information about a knowledge base
  187. # Args:
  188. # kb_id: Knowledge base ID
  189. # Returns:
  190. # Dictionary containing knowledge base details
  191. fields = [
  192. cls.model.id,
  193. cls.model.embd_id,
  194. cls.model.avatar,
  195. cls.model.name,
  196. cls.model.language,
  197. cls.model.description,
  198. cls.model.permission,
  199. cls.model.doc_num,
  200. cls.model.token_num,
  201. cls.model.chunk_num,
  202. cls.model.parser_id,
  203. cls.model.parser_config,
  204. cls.model.pagerank,
  205. cls.model.create_time,
  206. cls.model.update_time
  207. ]
  208. kbs = cls.model.select(*fields).join(Tenant, on=(
  209. (Tenant.id == cls.model.tenant_id) & (Tenant.status == StatusEnum.VALID.value))).where(
  210. (cls.model.id == kb_id),
  211. (cls.model.status == StatusEnum.VALID.value)
  212. )
  213. if not kbs:
  214. return
  215. d = kbs[0].to_dict()
  216. return d
  217. @classmethod
  218. @DB.connection_context()
  219. def update_parser_config(cls, id, config):
  220. # Update parser configuration for a knowledge base
  221. # Args:
  222. # id: Knowledge base ID
  223. # config: New parser configuration
  224. e, m = cls.get_by_id(id)
  225. if not e:
  226. raise LookupError(f"knowledgebase({id}) not found.")
  227. def dfs_update(old, new):
  228. # Deep update of nested configuration
  229. for k, v in new.items():
  230. if k not in old:
  231. old[k] = v
  232. continue
  233. if isinstance(v, dict):
  234. assert isinstance(old[k], dict)
  235. dfs_update(old[k], v)
  236. elif isinstance(v, list):
  237. assert isinstance(old[k], list)
  238. old[k] = list(set(old[k] + v))
  239. else:
  240. old[k] = v
  241. dfs_update(m.parser_config, config)
  242. cls.update_by_id(id, {"parser_config": m.parser_config})
  243. @classmethod
  244. @DB.connection_context()
  245. def get_field_map(cls, ids):
  246. # Get field mappings for knowledge bases
  247. # Args:
  248. # ids: List of knowledge base IDs
  249. # Returns:
  250. # Dictionary of field mappings
  251. conf = {}
  252. for k in cls.get_by_ids(ids):
  253. if k.parser_config and "field_map" in k.parser_config:
  254. conf.update(k.parser_config["field_map"])
  255. return conf
  256. @classmethod
  257. @DB.connection_context()
  258. def get_by_name(cls, kb_name, tenant_id):
  259. # Get knowledge base by name and tenant ID
  260. # Args:
  261. # kb_name: Knowledge base name
  262. # tenant_id: Tenant ID
  263. # Returns:
  264. # Tuple of (exists, knowledge_base)
  265. kb = cls.model.select().where(
  266. (cls.model.name == kb_name)
  267. & (cls.model.tenant_id == tenant_id)
  268. & (cls.model.status == StatusEnum.VALID.value)
  269. )
  270. if kb:
  271. return True, kb[0]
  272. return False, None
  273. @classmethod
  274. @DB.connection_context()
  275. def get_all_ids(cls):
  276. # Get all knowledge base IDs
  277. # Returns:
  278. # List of all knowledge base IDs
  279. return [m["id"] for m in cls.model.select(cls.model.id).dicts()]
  280. @classmethod
  281. @DB.connection_context()
  282. def get_list(cls, joined_tenant_ids, user_id,
  283. page_number, items_per_page, orderby, desc, id, name):
  284. # Get list of knowledge bases with filtering and pagination
  285. # Args:
  286. # joined_tenant_ids: List of tenant IDs
  287. # user_id: Current user ID
  288. # page_number: Page number for pagination
  289. # items_per_page: Number of items per page
  290. # orderby: Field to order by
  291. # desc: Boolean indicating descending order
  292. # id: Optional ID filter
  293. # name: Optional name filter
  294. # Returns:
  295. # List of knowledge bases
  296. kbs = cls.model.select()
  297. if id:
  298. kbs = kbs.where(cls.model.id == id)
  299. if name:
  300. kbs = kbs.where(cls.model.name == name)
  301. kbs = kbs.where(
  302. ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
  303. TenantPermission.TEAM.value)) | (
  304. cls.model.tenant_id == user_id))
  305. & (cls.model.status == StatusEnum.VALID.value)
  306. )
  307. if desc:
  308. kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
  309. else:
  310. kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
  311. kbs = kbs.paginate(page_number, items_per_page)
  312. return list(kbs.dicts())
  313. @classmethod
  314. @DB.connection_context()
  315. def accessible(cls, kb_id, user_id):
  316. # Check if a knowledge base is accessible by a user
  317. # Args:
  318. # kb_id: Knowledge base ID
  319. # user_id: User ID
  320. # Returns:
  321. # Boolean indicating accessibility
  322. docs = cls.model.select(
  323. cls.model.id).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  324. ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
  325. docs = docs.dicts()
  326. if not docs:
  327. return False
  328. return True
  329. @classmethod
  330. @DB.connection_context()
  331. def get_kb_by_id(cls, kb_id, user_id):
  332. # Get knowledge base by ID and user ID
  333. # Args:
  334. # kb_id: Knowledge base ID
  335. # user_id: User ID
  336. # Returns:
  337. # List containing knowledge base information
  338. kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  339. ).where(cls.model.id == kb_id, UserTenant.user_id == user_id).paginate(0, 1)
  340. kbs = kbs.dicts()
  341. return list(kbs)
  342. @classmethod
  343. @DB.connection_context()
  344. def get_kb_by_name(cls, kb_name, user_id):
  345. # Get knowledge base by name and user ID
  346. # Args:
  347. # kb_name: Knowledge base name
  348. # user_id: User ID
  349. # Returns:
  350. # List containing knowledge base information
  351. kbs = cls.model.select().join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
  352. ).where(cls.model.name == kb_name, UserTenant.user_id == user_id).paginate(0, 1)
  353. kbs = kbs.dicts()
  354. return list(kbs)
  355. @classmethod
  356. @DB.connection_context()
  357. def atomic_increase_doc_num_by_id(cls, kb_id):
  358. data = {}
  359. data["update_time"] = current_timestamp()
  360. data["update_date"] = datetime_format(datetime.now())
  361. data["doc_num"] = cls.model.doc_num + 1
  362. num = cls.model.update(data).where(cls.model.id == kb_id).execute()
  363. return num
  364. @classmethod
  365. @DB.connection_context()
  366. def update_document_number_in_init(cls, kb_id, doc_num):
  367. """
  368. Only use this function when init system
  369. """
  370. ok, kb = cls.get_by_id(kb_id)
  371. if not ok:
  372. return
  373. kb.doc_num = doc_num
  374. dirty_fields = kb.dirty_fields
  375. if cls.model._meta.combined.get("update_time") in dirty_fields:
  376. dirty_fields.remove(cls.model._meta.combined["update_time"])
  377. if cls.model._meta.combined.get("update_date") in dirty_fields:
  378. dirty_fields.remove(cls.model._meta.combined["update_date"])
  379. try:
  380. kb.save(only=dirty_fields)
  381. except ValueError as e:
  382. if str(e) == "no data to save!":
  383. pass # that's OK
  384. else:
  385. raise e