You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import hashlib
  17. from datetime import datetime
  18. import logging
  19. import peewee
  20. from werkzeug.security import generate_password_hash, check_password_hash
  21. from api.db import UserTenantRole
  22. from api.db.db_models import DB, UserTenant
  23. from api.db.db_models import User, Tenant
  24. from api.db.services.common_service import CommonService
  25. from api.utils import get_uuid, current_timestamp, datetime_format
  26. from api.db import StatusEnum
  27. from rag.settings import MINIO
  28. class UserService(CommonService):
  29. """Service class for managing user-related database operations.
  30. This class extends CommonService to provide specialized functionality for user management,
  31. including authentication, user creation, updates, and deletions.
  32. Attributes:
  33. model: The User model class for database operations.
  34. """
  35. model = User
  36. @classmethod
  37. @DB.connection_context()
  38. def query(cls, cols=None, reverse=None, order_by=None, **kwargs):
  39. if 'access_token' in kwargs:
  40. access_token = kwargs['access_token']
  41. # Reject empty, None, or whitespace-only access tokens
  42. if not access_token or not str(access_token).strip():
  43. logging.warning("UserService.query: Rejecting empty access_token query")
  44. return cls.model.select().where(cls.model.id == "INVALID_EMPTY_TOKEN") # Returns empty result
  45. # Reject tokens that are too short (should be UUID, 32+ chars)
  46. if len(str(access_token).strip()) < 32:
  47. logging.warning(f"UserService.query: Rejecting short access_token query: {len(str(access_token))} chars")
  48. return cls.model.select().where(cls.model.id == "INVALID_SHORT_TOKEN") # Returns empty result
  49. # Reject tokens that start with "INVALID_" (from logout)
  50. if str(access_token).startswith("INVALID_"):
  51. logging.warning("UserService.query: Rejecting invalidated access_token")
  52. return cls.model.select().where(cls.model.id == "INVALID_LOGOUT_TOKEN") # Returns empty result
  53. # Call parent query method for valid requests
  54. return super().query(cols=cols, reverse=reverse, order_by=order_by, **kwargs)
  55. @classmethod
  56. @DB.connection_context()
  57. def filter_by_id(cls, user_id):
  58. """Retrieve a user by their ID.
  59. Args:
  60. user_id: The unique identifier of the user.
  61. Returns:
  62. User object if found, None otherwise.
  63. """
  64. try:
  65. user = cls.model.select().where(cls.model.id == user_id).get()
  66. return user
  67. except peewee.DoesNotExist:
  68. return None
  69. @classmethod
  70. @DB.connection_context()
  71. def query_user(cls, email, password):
  72. """Authenticate a user with email and password.
  73. Args:
  74. email: User's email address.
  75. password: User's password in plain text.
  76. Returns:
  77. User object if authentication successful, None otherwise.
  78. """
  79. user = cls.model.select().where((cls.model.email == email),
  80. (cls.model.status == StatusEnum.VALID.value)).first()
  81. if user and check_password_hash(str(user.password), password):
  82. return user
  83. else:
  84. return None
  85. @classmethod
  86. @DB.connection_context()
  87. def save(cls, **kwargs):
  88. if "id" not in kwargs:
  89. kwargs["id"] = get_uuid()
  90. if "password" in kwargs:
  91. kwargs["password"] = generate_password_hash(
  92. str(kwargs["password"]))
  93. kwargs["create_time"] = current_timestamp()
  94. kwargs["create_date"] = datetime_format(datetime.now())
  95. kwargs["update_time"] = current_timestamp()
  96. kwargs["update_date"] = datetime_format(datetime.now())
  97. obj = cls.model(**kwargs).save(force_insert=True)
  98. return obj
  99. @classmethod
  100. @DB.connection_context()
  101. def delete_user(cls, user_ids, update_user_dict):
  102. with DB.atomic():
  103. cls.model.update({"status": 0}).where(
  104. cls.model.id.in_(user_ids)).execute()
  105. @classmethod
  106. @DB.connection_context()
  107. def update_user(cls, user_id, user_dict):
  108. with DB.atomic():
  109. if user_dict:
  110. user_dict["update_time"] = current_timestamp()
  111. user_dict["update_date"] = datetime_format(datetime.now())
  112. cls.model.update(user_dict).where(
  113. cls.model.id == user_id).execute()
  114. class TenantService(CommonService):
  115. """Service class for managing tenant-related database operations.
  116. This class extends CommonService to provide functionality for tenant management,
  117. including tenant information retrieval and credit management.
  118. Attributes:
  119. model: The Tenant model class for database operations.
  120. """
  121. model = Tenant
  122. @classmethod
  123. @DB.connection_context()
  124. def get_info_by(cls, user_id):
  125. fields = [
  126. cls.model.id.alias("tenant_id"),
  127. cls.model.name,
  128. cls.model.llm_id,
  129. cls.model.embd_id,
  130. cls.model.rerank_id,
  131. cls.model.asr_id,
  132. cls.model.img2txt_id,
  133. cls.model.tts_id,
  134. cls.model.parser_ids,
  135. UserTenant.role]
  136. return list(cls.model.select(*fields)
  137. .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.OWNER)))
  138. .where(cls.model.status == StatusEnum.VALID.value).dicts())
  139. @classmethod
  140. @DB.connection_context()
  141. def get_joined_tenants_by_user_id(cls, user_id):
  142. fields = [
  143. cls.model.id.alias("tenant_id"),
  144. cls.model.name,
  145. cls.model.llm_id,
  146. cls.model.embd_id,
  147. cls.model.asr_id,
  148. cls.model.img2txt_id,
  149. UserTenant.role]
  150. return list(cls.model.select(*fields)
  151. .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role == UserTenantRole.NORMAL)))
  152. .where(cls.model.status == StatusEnum.VALID.value).dicts())
  153. @classmethod
  154. @DB.connection_context()
  155. def decrease(cls, user_id, num):
  156. num = cls.model.update(credit=cls.model.credit - num).where(
  157. cls.model.id == user_id).execute()
  158. if num == 0:
  159. raise LookupError("Tenant not found which is supposed to be there")
  160. @classmethod
  161. @DB.connection_context()
  162. def user_gateway(cls, tenant_id):
  163. hashobj = hashlib.sha256(tenant_id.encode("utf-8"))
  164. return int(hashobj.hexdigest(), 16)%len(MINIO)
  165. class UserTenantService(CommonService):
  166. """Service class for managing user-tenant relationship operations.
  167. This class extends CommonService to handle the many-to-many relationship
  168. between users and tenants, managing user roles and tenant memberships.
  169. Attributes:
  170. model: The UserTenant model class for database operations.
  171. """
  172. model = UserTenant
  173. @classmethod
  174. @DB.connection_context()
  175. def filter_by_id(cls, user_tenant_id):
  176. try:
  177. user_tenant = cls.model.select().where((cls.model.id == user_tenant_id) & (cls.model.status == StatusEnum.VALID.value)).get()
  178. return user_tenant
  179. except peewee.DoesNotExist:
  180. return None
  181. @classmethod
  182. @DB.connection_context()
  183. def save(cls, **kwargs):
  184. if "id" not in kwargs:
  185. kwargs["id"] = get_uuid()
  186. obj = cls.model(**kwargs).save(force_insert=True)
  187. return obj
  188. @classmethod
  189. @DB.connection_context()
  190. def get_by_tenant_id(cls, tenant_id):
  191. fields = [
  192. cls.model.id,
  193. cls.model.user_id,
  194. cls.model.status,
  195. cls.model.role,
  196. User.nickname,
  197. User.email,
  198. User.avatar,
  199. User.is_authenticated,
  200. User.is_active,
  201. User.is_anonymous,
  202. User.status,
  203. User.update_date,
  204. User.is_superuser]
  205. return list(cls.model.select(*fields)
  206. .join(User, on=((cls.model.user_id == User.id) & (cls.model.status == StatusEnum.VALID.value) & (cls.model.role != UserTenantRole.OWNER)))
  207. .where(cls.model.tenant_id == tenant_id)
  208. .dicts())
  209. @classmethod
  210. @DB.connection_context()
  211. def get_tenants_by_user_id(cls, user_id):
  212. fields = [
  213. cls.model.tenant_id,
  214. cls.model.role,
  215. User.nickname,
  216. User.email,
  217. User.avatar,
  218. User.update_date
  219. ]
  220. return list(cls.model.select(*fields)
  221. .join(User, on=((cls.model.tenant_id == User.id) & (UserTenant.user_id == user_id) & (UserTenant.status == StatusEnum.VALID.value)))
  222. .where(cls.model.status == StatusEnum.VALID.value).dicts())
  223. @classmethod
  224. @DB.connection_context()
  225. def get_num_members(cls, user_id: str):
  226. cnt_members = cls.model.select(peewee.fn.COUNT(cls.model.id)).where(cls.model.tenant_id == user_id).scalar()
  227. return cnt_members
  228. @classmethod
  229. @DB.connection_context()
  230. def filter_by_tenant_and_user_id(cls, tenant_id, user_id):
  231. try:
  232. user_tenant = cls.model.select().where(
  233. (cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value) &
  234. (cls.model.user_id == user_id)
  235. ).first()
  236. return user_tenant
  237. except peewee.DoesNotExist:
  238. return None