You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

db_models.py 41KB


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import hashlib
  17. import inspect
  18. import logging
  19. import operator
  20. import os
  21. import sys
  22. import time
  23. import typing
  24. from enum import Enum
  25. from functools import wraps
  26. from flask_login import UserMixin
  27. from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
  28. from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
  29. from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
  30. from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
  31. from api import settings, utils
  32. from api.db import ParserType, SerializedType
  33. def singleton(cls, *args, **kw):
  34. instances = {}
  35. def _singleton():
  36. key = str(cls) + str(os.getpid())
  37. if key not in instances:
  38. instances[key] = cls(*args, **kw)
  39. return instances[key]
  40. return _singleton
# Field types treated as "continuous" by is_continuous_field(): BaseModel.query
# turns a 2-element list value on such a field into a range filter instead of IN.
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
# Prefixes of the auto-managed "<prefix>_time" timestamp fields. BaseModel._normalize_data
# fills the matching "<prefix>_date" column whenever a model has both columns.
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
class TextFieldType(Enum):
    """SQL column type used for unbounded text, keyed by database backend name."""
    MYSQL = "LONGTEXT"
    POSTGRES = "TEXT"
class LongTextField(TextField):
    """Text field whose SQL type matches the configured backend (LONGTEXT on MySQL, TEXT on Postgres)."""
    # Resolved once at import time; an unsupported settings.DATABASE_TYPE raises KeyError here.
    field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
  48. class JSONField(LongTextField):
  49. default_value = {}
  50. def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
  51. self._object_hook = object_hook
  52. self._object_pairs_hook = object_pairs_hook
  53. super().__init__(**kwargs)
  54. def db_value(self, value):
  55. if value is None:
  56. value = self.default_value
  57. return utils.json_dumps(value)
  58. def python_value(self, value):
  59. if not value:
  60. return self.default_value
  61. return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
    """JSONField whose fallback for None/empty values is a list instead of a dict."""
    default_value = []
  64. class SerializedField(LongTextField):
  65. def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
  66. self._serialized_type = serialized_type
  67. self._object_hook = object_hook
  68. self._object_pairs_hook = object_pairs_hook
  69. super().__init__(**kwargs)
  70. def db_value(self, value):
  71. if self._serialized_type == SerializedType.PICKLE:
  72. return utils.serialize_b64(value, to_str=True)
  73. elif self._serialized_type == SerializedType.JSON:
  74. if value is None:
  75. return None
  76. return utils.json_dumps(value, with_type=True)
  77. else:
  78. raise ValueError(f"the serialized type {self._serialized_type} is not supported")
  79. def python_value(self, value):
  80. if self._serialized_type == SerializedType.PICKLE:
  81. return utils.deserialize_b64(value)
  82. elif self._serialized_type == SerializedType.JSON:
  83. if value is None:
  84. return {}
  85. return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  86. else:
  87. raise ValueError(f"the serialized type {self._serialized_type} is not supported")
  88. def is_continuous_field(cls: typing.Type) -> bool:
  89. if cls in CONTINUOUS_FIELD_TYPE:
  90. return True
  91. for p in cls.__bases__:
  92. if p in CONTINUOUS_FIELD_TYPE:
  93. return True
  94. elif p is not Field and p is not object:
  95. if is_continuous_field(p):
  96. return True
  97. else:
  98. return False
  99. def auto_date_timestamp_field():
  100. return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  101. def auto_date_timestamp_db_field():
  102. return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  103. def remove_field_name_prefix(field_name):
  104. return field_name[2:] if field_name.startswith("f_") else field_name
  105. class BaseModel(Model):
  106. create_time = BigIntegerField(null=True, index=True)
  107. create_date = DateTimeField(null=True, index=True)
  108. update_time = BigIntegerField(null=True, index=True)
  109. update_date = DateTimeField(null=True, index=True)
  110. def to_json(self):
  111. # This function is obsolete
  112. return self.to_dict()
  113. def to_dict(self):
  114. return self.__dict__["__data__"]
  115. def to_human_model_dict(self, only_primary_with: list = None):
  116. model_dict = self.__dict__["__data__"]
  117. if not only_primary_with:
  118. return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
  119. human_model_dict = {}
  120. for k in self._meta.primary_key.field_names:
  121. human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
  122. for k in only_primary_with:
  123. human_model_dict[k] = model_dict[f"f_{k}"]
  124. return human_model_dict
  125. @property
  126. def meta(self) -> Metadata:
  127. return self._meta
  128. @classmethod
  129. def get_primary_keys_name(cls):
  130. return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [cls._meta.primary_key.name]
  131. @classmethod
  132. def getter_by(cls, attr):
  133. return operator.attrgetter(attr)(cls)
  134. @classmethod
  135. def query(cls, reverse=None, order_by=None, **kwargs):
  136. filters = []
  137. for f_n, f_v in kwargs.items():
  138. attr_name = "%s" % f_n
  139. if not hasattr(cls, attr_name) or f_v is None:
  140. continue
  141. if type(f_v) in {list, set}:
  142. f_v = list(f_v)
  143. if is_continuous_field(type(getattr(cls, attr_name))):
  144. if len(f_v) == 2:
  145. for i, v in enumerate(f_v):
  146. if isinstance(v, str) and f_n in auto_date_timestamp_field():
  147. # time type: %Y-%m-%d %H:%M:%S
  148. f_v[i] = utils.date_string_to_timestamp(v)
  149. lt_value = f_v[0]
  150. gt_value = f_v[1]
  151. if lt_value is not None and gt_value is not None:
  152. filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
  153. elif lt_value is not None:
  154. filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
  155. elif gt_value is not None:
  156. filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
  157. else:
  158. filters.append(operator.attrgetter(attr_name)(cls) << f_v)
  159. else:
  160. filters.append(operator.attrgetter(attr_name)(cls) == f_v)
  161. if filters:
  162. query_records = cls.select().where(*filters)
  163. if reverse is not None:
  164. if not order_by or not hasattr(cls, f"{order_by}"):
  165. order_by = "create_time"
  166. if reverse is True:
  167. query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
  168. elif reverse is False:
  169. query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
  170. return [query_record for query_record in query_records]
  171. else:
  172. return []
  173. @classmethod
  174. def insert(cls, __data=None, **insert):
  175. if isinstance(__data, dict) and __data:
  176. __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
  177. if insert:
  178. insert["create_time"] = utils.current_timestamp()
  179. return super().insert(__data, **insert)
  180. # update and insert will call this method
  181. @classmethod
  182. def _normalize_data(cls, data, kwargs):
  183. normalized = super()._normalize_data(data, kwargs)
  184. if not normalized:
  185. return {}
  186. normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
  187. for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
  188. if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and cls._meta.combined[f"{f_n}_time"] in normalized and normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
  189. normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(normalized[cls._meta.combined[f"{f_n}_time"]])
  190. return normalized
  191. class JsonSerializedField(SerializedField):
  192. def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
  193. super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, object_pairs_hook=object_pairs_hook, **kwargs)
  194. class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
  195. def __init__(self, *args, **kwargs):
  196. self.max_retries = kwargs.pop('max_retries', 5)
  197. self.retry_delay = kwargs.pop('retry_delay', 1)
  198. super().__init__(*args, **kwargs)
  199. def execute_sql(self, sql, params=None, commit=True):
  200. from peewee import OperationalError
  201. for attempt in range(self.max_retries + 1):
  202. try:
  203. return super().execute_sql(sql, params, commit)
  204. except OperationalError as e:
  205. if e.args[0] in (2013, 2006) and attempt < self.max_retries:
  206. logging.warning(
  207. f"Lost connection (attempt {attempt+1}/{self.max_retries}): {e}"
  208. )
  209. self._handle_connection_loss()
  210. time.sleep(self.retry_delay * (2 ** attempt))
  211. else:
  212. logging.error(f"DB execution failure: {e}")
  213. raise
  214. return None
  215. def _handle_connection_loss(self):
  216. self.close_all()
  217. self.connect()
  218. def begin(self):
  219. from peewee import OperationalError
  220. for attempt in range(self.max_retries + 1):
  221. try:
  222. return super().begin()
  223. except OperationalError as e:
  224. if e.args[0] in (2013, 2006) and attempt < self.max_retries:
  225. logging.warning(
  226. f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
  227. )
  228. self._handle_connection_loss()
  229. time.sleep(self.retry_delay * (2 ** attempt))
  230. else:
  231. raise
class PooledDatabase(Enum):
    """Pooled database class per backend; looked up by settings.DATABASE_TYPE.upper()."""
    # MySQL uses the reconnecting subclass defined above.
    MYSQL = RetryingPooledMySQLDatabase
    POSTGRES = PooledPostgresqlDatabase
class DatabaseMigrator(Enum):
    """playhouse schema-migrator class per backend.

    NOTE(review): not referenced in the visible code; presumably used by
    migrate_db() elsewhere in this file - confirm.
    """
    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator
  238. @singleton
  239. class BaseDataBase:
  240. def __init__(self):
  241. database_config = settings.DATABASE.copy()
  242. db_name = database_config.pop("name")
  243. self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
  244. logging.info("init database on cluster mode successfully")
  245. def with_retry(max_retries=3, retry_delay=1.0):
  246. """Decorator: Add retry mechanism to database operations
  247. Args:
  248. max_retries (int): maximum number of retries
  249. retry_delay (float): initial retry delay (seconds), will increase exponentially
  250. Returns:
  251. decorated function
  252. """
  253. def decorator(func):
  254. @wraps(func)
  255. def wrapper(*args, **kwargs):
  256. last_exception = None
  257. for retry in range(max_retries):
  258. try:
  259. return func(*args, **kwargs)
  260. except Exception as e:
  261. last_exception = e
  262. # get self and method name for logging
  263. self_obj = args[0] if args else None
  264. func_name = func.__name__
  265. lock_name = getattr(self_obj, "lock_name", "unknown") if self_obj else "unknown"
  266. if retry < max_retries - 1:
  267. current_delay = retry_delay * (2**retry)
  268. logging.warning(f"{func_name} {lock_name} failed: {str(e)}, retrying ({retry + 1}/{max_retries})")
  269. time.sleep(current_delay)
  270. else:
  271. logging.error(f"{func_name} {lock_name} failed after all attempts: {str(e)}")
  272. if last_exception:
  273. raise last_exception
  274. return False
  275. return wrapper
  276. return decorator
  277. class PostgresDatabaseLock:
  278. def __init__(self, lock_name, timeout=10, db=None):
  279. self.lock_name = lock_name
  280. self.lock_id = int(hashlib.md5(lock_name.encode()).hexdigest(), 16) % (2**31 - 1)
  281. self.timeout = int(timeout)
  282. self.db = db if db else DB
  283. @with_retry(max_retries=3, retry_delay=1.0)
  284. def lock(self):
  285. cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", (self.lock_id,))
  286. ret = cursor.fetchone()
  287. if ret[0] == 0:
  288. raise Exception(f"acquire postgres lock {self.lock_name} timeout")
  289. elif ret[0] == 1:
  290. return True
  291. else:
  292. raise Exception(f"failed to acquire lock {self.lock_name}")
  293. @with_retry(max_retries=3, retry_delay=1.0)
  294. def unlock(self):
  295. cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", (self.lock_id,))
  296. ret = cursor.fetchone()
  297. if ret[0] == 0:
  298. raise Exception(f"postgres lock {self.lock_name} was not established by this thread")
  299. elif ret[0] == 1:
  300. return True
  301. else:
  302. raise Exception(f"postgres lock {self.lock_name} does not exist")
  303. def __enter__(self):
  304. if isinstance(self.db, PooledPostgresqlDatabase):
  305. self.lock()
  306. return self
  307. def __exit__(self, exc_type, exc_val, exc_tb):
  308. if isinstance(self.db, PooledPostgresqlDatabase):
  309. self.unlock()
  310. def __call__(self, func):
  311. @wraps(func)
  312. def magic(*args, **kwargs):
  313. with self:
  314. return func(*args, **kwargs)
  315. return magic
  316. class MysqlDatabaseLock:
  317. def __init__(self, lock_name, timeout=10, db=None):
  318. self.lock_name = lock_name
  319. self.timeout = int(timeout)
  320. self.db = db if db else DB
  321. @with_retry(max_retries=3, retry_delay=1.0)
  322. def lock(self):
  323. # SQL parameters only support %s format placeholders
  324. cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
  325. ret = cursor.fetchone()
  326. if ret[0] == 0:
  327. raise Exception(f"acquire mysql lock {self.lock_name} timeout")
  328. elif ret[0] == 1:
  329. return True
  330. else:
  331. raise Exception(f"failed to acquire lock {self.lock_name}")
  332. @with_retry(max_retries=3, retry_delay=1.0)
  333. def unlock(self):
  334. cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
  335. ret = cursor.fetchone()
  336. if ret[0] == 0:
  337. raise Exception(f"mysql lock {self.lock_name} was not established by this thread")
  338. elif ret[0] == 1:
  339. return True
  340. else:
  341. raise Exception(f"mysql lock {self.lock_name} does not exist")
  342. def __enter__(self):
  343. if isinstance(self.db, PooledMySQLDatabase):
  344. self.lock()
  345. return self
  346. def __exit__(self, exc_type, exc_val, exc_tb):
  347. if isinstance(self.db, PooledMySQLDatabase):
  348. self.unlock()
  349. def __call__(self, func):
  350. @wraps(func)
  351. def magic(*args, **kwargs):
  352. with self:
  353. return func(*args, **kwargs)
  354. return magic
class DatabaseLock(Enum):
    """Lock class per backend; selected with settings.DATABASE_TYPE.upper() when DB.lock is set."""
    MYSQL = MysqlDatabaseLock
    POSTGRES = PostgresDatabaseLock
# Module-level pooled connection shared by every DataBaseModel (created at import time).
DB = BaseDataBase().database_connection
# Attach the backend-matching lock class so callers can use DB.lock(name, timeout)
# as a context manager or decorator.
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
  360. def close_connection():
  361. try:
  362. if DB:
  363. DB.close_stale(age=30)
  364. except Exception as e:
  365. logging.exception(e)
class DataBaseModel(BaseModel):
    """Base class for persisted models; binds them to the shared module-level DB.

    init_database_tables() creates a table for every subclass defined in this module.
    """
    class Meta:
        database = DB
  369. @DB.connection_context()
  370. @DB.lock("init_database_tables", 60)
  371. def init_database_tables(alter_fields=[]):
  372. members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
  373. table_objs = []
  374. create_failed_list = []
  375. for name, obj in members:
  376. if obj != DataBaseModel and issubclass(obj, DataBaseModel):
  377. table_objs.append(obj)
  378. if not obj.table_exists():
  379. logging.debug(f"start create table {obj.__name__}")
  380. try:
  381. obj.create_table(safe=True)
  382. logging.debug(f"create table success: {obj.__name__}")
  383. except Exception as e:
  384. logging.exception(e)
  385. create_failed_list.append(obj.__name__)
  386. else:
  387. logging.debug(f"table {obj.__name__} already exists, skip creation.")
  388. if create_failed_list:
  389. logging.error(f"create tables failed: {create_failed_list}")
  390. raise Exception(f"create tables failed: {create_failed_list}")
  391. migrate_db()
  392. def fill_db_model_object(model_object, human_model_dict):
  393. for k, v in human_model_dict.items():
  394. attr_name = "%s" % k
  395. if hasattr(model_object.__class__, attr_name):
  396. setattr(model_object, attr_name, v)
  397. return model_object
class User(DataBaseModel, UserMixin):
    """Registered user account (table ``user``), also used as the Flask-Login user object."""
    id = CharField(max_length=32, primary_key=True)
    access_token = CharField(max_length=255, null=True, index=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
    password = CharField(max_length=255, null=True, help_text="password", index=True)
    email = CharField(max_length=255, null=False, help_text="email", index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    # Default language chosen at import time from the server's LANG environment variable.
    language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", index=True)
    color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright", index=True)
    timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai", index=True)
    last_login_time = DateTimeField(null=True, index=True)
    # Stored as "0"/"1" strings rather than booleans.
    is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
    is_active = CharField(max_length=1, null=False, default="1", index=True)
    is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
    login_channel = CharField(null=True, help_text="from which user login", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
    is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)

    def __str__(self):
        return self.email

    def get_id(self):
        """Return the Flask-Login session id: access_token signed with settings.SECRET_KEY."""
        jwt = Serializer(secret_key=settings.SECRET_KEY)
        return jwt.dumps(str(self.access_token))

    class Meta:
        db_table = "user"
class Tenant(DataBaseModel):
    """Tenant (workspace) with its default model IDs per model type (table ``tenant``)."""
    id = CharField(max_length=32, primary_key=True)
    name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
    public_key = CharField(max_length=255, null=True, index=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
    asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID", index=True)
    img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID", index=True)
    rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID", index=True)
    tts_id = CharField(max_length=256, null=True, help_text="default tts model ID", index=True)
    parser_ids = CharField(max_length=256, null=False, help_text="document processors", index=True)
    credit = IntegerField(default=512, index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "tenant"
class UserTenant(DataBaseModel):
    """Membership of a user in a tenant, with role and inviter (table ``user_tenant``)."""
    id = CharField(max_length=32, primary_key=True)
    user_id = CharField(max_length=32, null=False, index=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
    invited_by = CharField(max_length=32, null=False, index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "user_tenant"
class InvitationCode(DataBaseModel):
    """Invitation code, optionally bound to a user and tenant (table ``invitation_code``)."""
    id = CharField(max_length=32, primary_key=True)
    code = CharField(max_length=32, null=False, index=True)
    visit_time = DateTimeField(null=True, index=True)
    user_id = CharField(max_length=32, null=True, index=True)
    tenant_id = CharField(max_length=32, null=True, index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "invitation_code"
class LLMFactories(DataBaseModel):
    """LLM provider ("factory"), keyed by name (table ``llm_factories``)."""
    name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "llm_factories"
class LLM(DataBaseModel):
    """Catalog entry for one model of a provider; primary key is (fid, llm_name) (table ``llm``)."""
    # LLMs dictionary
    llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
    model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    # Provider name; presumably matches LLMFactories.name - confirm.
    fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
    max_tokens = IntegerField(default=0)
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True)
    is_tools = BooleanField(null=False, help_text="support tools", default=False)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.llm_name

    class Meta:
        primary_key = CompositeKey("fid", "llm_name")
        db_table = "llm"
class TenantLLM(DataBaseModel):
    """Per-tenant model configuration and usage counters (table ``tenant_llm``).

    The primary key is (tenant_id, llm_factory, llm_name).
    """
    tenant_id = CharField(max_length=32, null=False, index=True)
    llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
    model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
    # NOTE(review): API key stored as a plain indexed column - confirm whether it is encrypted upstream.
    api_key = CharField(max_length=2048, null=True, help_text="API KEY", index=True)
    api_base = CharField(max_length=255, null=True, help_text="API Base")
    max_tokens = IntegerField(default=8192, index=True)
    used_tokens = IntegerField(default=0, index=True)

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "tenant_llm"
        primary_key = CompositeKey("tenant_id", "llm_factory", "llm_name")
  492. class TenantLangfuse(DataBaseModel):
  493. tenant_id = CharField(max_length=32, null=False, primary_key=True)
  494. secret_key = CharField(max_length=2048, null=False, help_text="SECRET KEY", index=True)
  495. public_key = CharField(max_length=2048, null=False, help_text="PUBLIC KEY", index=True)
  496. host = CharField(max_length=128, null=False, help_text="HOST", index=True)
  497. def __str__(self):
  498. return "Langfuse host" + self.host
  499. class Meta:
  500. db_table = "tenant_langfuse"
class Knowledgebase(DataBaseModel):
    """Knowledge base owned by a tenant, with retrieval and parsing defaults (table ``knowledgebase``)."""
    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="KB name", index=True)
    # Default language chosen at import time from the server's LANG environment variable.
    language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
    description = TextField(null=True, help_text="KB description")
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
    permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
    created_by = CharField(max_length=32, null=False, index=True)
    doc_num = IntegerField(default=0, index=True)
    token_num = IntegerField(default=0, index=True)
    chunk_num = IntegerField(default=0, index=True)
    similarity_threshold = FloatField(default=0.2, index=True)
    vector_similarity_weight = FloatField(default=0.3, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
    # NOTE(review): non-callable mutable default; unless peewee copies it per
    # instance, in-place edits to a new row's parser_config would be shared - confirm.
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    pagerank = IntegerField(default=0, index=False)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "knowledgebase"
class Document(DataBaseModel):
    """A document in a knowledge base plus its parsing progress (table ``document``)."""
    id = CharField(max_length=32, primary_key=True)
    thumbnail = TextField(null=True, help_text="thumbnail base64 string")
    kb_id = CharField(max_length=256, null=False, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
    # NOTE(review): non-callable mutable default; unless peewee copies it per
    # instance, in-place edits to a new row's parser_config would be shared - confirm.
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
    type = CharField(max_length=32, null=False, help_text="file extension", index=True)
    created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
    name = CharField(max_length=255, null=True, help_text="file name", index=True)
    location = CharField(max_length=255, null=True, help_text="where dose it store", index=True)
    size = IntegerField(default=0, index=True)
    token_num = IntegerField(default=0, index=True)
    chunk_num = IntegerField(default=0, index=True)
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(null=True, help_text="process message", default="")
    process_begin_at = DateTimeField(null=True, index=True)
    process_duration = FloatField(default=0)
    meta_fields = JSONField(null=True, default={})
    suffix = CharField(max_length=32, null=False, help_text="The real file extension suffix", index=True)
    # "1" = run processing, "2" = cancel (per help_text); the default is "0".
    run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "document"
class File(DataBaseModel):
    """File or folder node in a tenant's file tree, linked to its parent (table ``file``)."""
    id = CharField(max_length=32, primary_key=True)
    parent_id = CharField(max_length=32, null=False, help_text="parent folder id", index=True)
    tenant_id = CharField(max_length=32, null=False, help_text="tenant id", index=True)
    created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
    name = CharField(max_length=255, null=False, help_text="file name or folder name", index=True)
    location = CharField(max_length=255, null=True, help_text="where dose it store", index=True)
    size = IntegerField(default=0, index=True)
    type = CharField(max_length=32, null=False, help_text="file extension", index=True)
    source_type = CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)

    class Meta:
        db_table = "file"
class File2Document(DataBaseModel):
    """Link row between a File and a Document (table ``file2document``)."""
    id = CharField(max_length=32, primary_key=True)
    file_id = CharField(max_length=32, null=True, help_text="file id", index=True)
    document_id = CharField(max_length=32, null=True, help_text="document id", index=True)

    class Meta:
        db_table = "file2document"
class Task(DataBaseModel):
    """Processing task over a page range of a document.

    NOTE(review): there is no Meta.db_table here, so peewee derives the table
    name from the class name - confirm it resolves to the intended table.
    """
    id = CharField(max_length=32, primary_key=True)
    doc_id = CharField(max_length=32, null=False, index=True)
    # Page range processed by this task.
    from_page = IntegerField(default=0)
    to_page = IntegerField(default=100000000)
    task_type = CharField(max_length=32, null=False, default="")
    priority = IntegerField(default=0)
    begin_at = DateTimeField(null=True, index=True)
    process_duration = FloatField(default=0)
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(null=True, help_text="process message", default="")
    retry_count = IntegerField(default=0)
    digest = TextField(null=True, help_text="task digest", default="")
    chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")
class Dialog(DataBaseModel):
    """Chat assistant configuration: LLM, prompt and retrieval settings (table ``dialog``)."""
    id = CharField(max_length=32, primary_key=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=255, null=True, help_text="dialog application name", index=True)
    description = TextField(null=True, help_text="Dialog description")
    icon = TextField(null=True, help_text="icon base64 string")
    # Default language chosen at import time from the server's LANG environment variable.
    language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    # NOTE(review): the dict/list JSONField defaults below are non-callable mutable
    # objects; unless peewee copies them per instance, in-place edits would be shared - confirm.
    llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, "presence_penalty": 0.4, "max_tokens": 512})
    prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced", index=True)
    prompt_config = JSONField(
        null=False,
        default={"system": "", "prologue": "Hi! I'm your assistant. What can I do for you?", "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"},
    )
    meta_data_filter = JSONField(null=True, default={})
    similarity_threshold = FloatField(default=0.2)
    vector_similarity_weight = FloatField(default=0.3)
    top_n = IntegerField(default=6)
    top_k = IntegerField(default=1024)
    do_refer = CharField(max_length=1, null=False, default="1", help_text="it needs to insert reference index into answer or not")
    rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID")
    kb_ids = JSONField(null=False, default=[])
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "dialog"
class Conversation(DataBaseModel):
    """Conversation history for a dialog: messages and references (table ``conversation``)."""
    id = CharField(max_length=32, primary_key=True)
    dialog_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
    message = JSONField(null=True)
    # NOTE(review): non-callable mutable list default - confirm it is not mutated in place.
    reference = JSONField(null=True, default=[])
    user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)

    class Meta:
        db_table = "conversation"
class APIToken(DataBaseModel):
    """API token issued to a tenant; primary key is (tenant_id, token) (table ``api_token``)."""
    tenant_id = CharField(max_length=32, null=False, index=True)
    token = CharField(max_length=255, null=False, index=True)
    dialog_id = CharField(max_length=32, null=True, index=True)
    source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
    beta = CharField(max_length=255, null=True, index=True)

    class Meta:
        db_table = "api_token"
        primary_key = CompositeKey("tenant_id", "token")
  623. class API4Conversation(DataBaseModel):
  624. id = CharField(max_length=32, primary_key=True)
  625. dialog_id = CharField(max_length=32, null=False, index=True)
  626. user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
  627. message = JSONField(null=True)
  628. reference = JSONField(null=True, default=[])
  629. tokens = IntegerField(default=0)
  630. source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
  631. dsl = JSONField(null=True, default={})
  632. duration = FloatField(default=0, index=True)
  633. round = IntegerField(default=0, index=True)
  634. thumb_up = IntegerField(default=0, index=True)
  635. errors = TextField(null=True, help_text="errors")
  636. class Meta:
  637. db_table = "api_4_conversation"
  638. class UserCanvas(DataBaseModel):
  639. id = CharField(max_length=32, primary_key=True)
  640. avatar = TextField(null=True, help_text="avatar base64 string")
  641. user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
  642. title = CharField(max_length=255, null=True, help_text="Canvas title")
  643. permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
  644. description = TextField(null=True, help_text="Canvas description")
  645. canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
  646. dsl = JSONField(null=True, default={})
  647. class Meta:
  648. db_table = "user_canvas"
  649. class CanvasTemplate(DataBaseModel):
  650. id = CharField(max_length=32, primary_key=True)
  651. avatar = TextField(null=True, help_text="avatar base64 string")
  652. title = CharField(max_length=255, null=True, help_text="Canvas title")
  653. description = TextField(null=True, help_text="Canvas description")
  654. canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
  655. dsl = JSONField(null=True, default={})
  656. class Meta:
  657. db_table = "canvas_template"
  658. class UserCanvasVersion(DataBaseModel):
  659. id = CharField(max_length=32, primary_key=True)
  660. user_canvas_id = CharField(max_length=255, null=False, help_text="user_canvas_id", index=True)
  661. title = CharField(max_length=255, null=True, help_text="Canvas title")
  662. description = TextField(null=True, help_text="Canvas description")
  663. dsl = JSONField(null=True, default={})
  664. class Meta:
  665. db_table = "user_canvas_version"
  666. class MCPServer(DataBaseModel):
  667. id = CharField(max_length=32, primary_key=True)
  668. name = CharField(max_length=255, null=False, help_text="MCP Server name")
  669. tenant_id = CharField(max_length=32, null=False, index=True)
  670. url = CharField(max_length=2048, null=False, help_text="MCP Server URL")
  671. server_type = CharField(max_length=32, null=False, help_text="MCP Server type")
  672. description = TextField(null=True, help_text="MCP Server description")
  673. variables = JSONField(null=True, default=dict, help_text="MCP Server variables")
  674. headers = JSONField(null=True, default=dict, help_text="MCP Server additional request headers")
  675. class Meta:
  676. db_table = "mcp_server"
  677. class Search(DataBaseModel):
  678. id = CharField(max_length=32, primary_key=True)
  679. avatar = TextField(null=True, help_text="avatar base64 string")
  680. tenant_id = CharField(max_length=32, null=False, index=True)
  681. name = CharField(max_length=128, null=False, help_text="Search name", index=True)
  682. description = TextField(null=True, help_text="KB description")
  683. created_by = CharField(max_length=32, null=False, index=True)
  684. search_config = JSONField(
  685. null=False,
  686. default={
  687. "kb_ids": [],
  688. "doc_ids": [],
  689. "similarity_threshold": 0.0,
  690. "vector_similarity_weight": 0.3,
  691. "use_kg": False,
  692. # rerank settings
  693. "rerank_id": "",
  694. "top_k": 1024,
  695. # chat settings
  696. "summary": False,
  697. "chat_id": "",
  698. "llm_setting": {
  699. "temperature": 0.1,
  700. "top_p": 0.3,
  701. "frequency_penalty": 0.7,
  702. "presence_penalty": 0.4,
  703. },
  704. "chat_settingcross_languages": [],
  705. "highlight": False,
  706. "keyword": False,
  707. "web_search": False,
  708. "related_search": False,
  709. "query_mindmap": False,
  710. },
  711. )
  712. status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
  713. def __str__(self):
  714. return self.name
  715. class Meta:
  716. db_table = "search"
  717. def migrate_db():
  718. logging.disable(logging.ERROR)
  719. migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
  720. try:
  721. migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
  722. except Exception:
  723. pass
  724. try:
  725. migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
  726. except Exception:
  727. pass
  728. try:
  729. migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
  730. except Exception:
  731. pass
  732. try:
  733. migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
  734. except Exception:
  735. pass
  736. try:
  737. migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
  738. except Exception:
  739. pass
  740. try:
  741. migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
  742. except Exception:
  743. pass
  744. try:
  745. migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
  746. except Exception:
  747. pass
  748. try:
  749. migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
  750. except Exception:
  751. pass
  752. try:
  753. migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
  754. except Exception:
  755. pass
  756. try:
  757. migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
  758. except Exception:
  759. pass
  760. try:
  761. migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
  762. except Exception:
  763. pass
  764. try:
  765. migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
  766. except Exception:
  767. pass
  768. try:
  769. migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
  770. except Exception:
  771. pass
  772. try:
  773. migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
  774. except Exception:
  775. pass
  776. try:
  777. migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
  778. except Exception:
  779. pass
  780. try:
  781. migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
  782. except Exception:
  783. pass
  784. try:
  785. migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
  786. except Exception:
  787. pass
  788. try:
  789. migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
  790. except Exception:
  791. pass
  792. try:
  793. migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
  794. except Exception:
  795. pass
  796. try:
  797. migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
  798. except Exception:
  799. pass
  800. try:
  801. migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
  802. except Exception:
  803. pass
  804. try:
  805. migrate(migrator.add_column("llm", "is_tools", BooleanField(null=False, help_text="support tools", default=False)))
  806. except Exception:
  807. pass
  808. try:
  809. migrate(migrator.add_column("mcp_server", "variables", JSONField(null=True, help_text="MCP Server variables", default=dict)))
  810. except Exception:
  811. pass
  812. try:
  813. migrate(migrator.rename_column("task", "process_duation", "process_duration"))
  814. except Exception:
  815. pass
  816. try:
  817. migrate(migrator.rename_column("document", "process_duation", "process_duration"))
  818. except Exception:
  819. pass
  820. try:
  821. migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
  822. except Exception:
  823. pass
  824. try:
  825. migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
  826. except Exception:
  827. pass
  828. try:
  829. migrate(migrator.add_column("dialog", "meta_data_filter", JSONField(null=True, default={})))
  830. except Exception:
  831. pass
  832. logging.disable(logging.NOTSET)