Ви не можете вибрати більше 25 тем Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.


  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import inspect
  17. import logging
  18. import operator
  19. import os
  20. import sys
  21. import typing
  22. from enum import Enum
  23. from functools import wraps
  24. from flask_login import UserMixin
  25. from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
  26. from peewee import BigIntegerField, BooleanField, CharField, CompositeKey, DateTimeField, Field, FloatField, IntegerField, Metadata, Model, TextField
  27. from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
  28. from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
  29. from api import settings, utils
  30. from api.db import ParserType, SerializedType
  31. def singleton(cls, *args, **kw):
  32. instances = {}
  33. def _singleton():
  34. key = str(cls) + str(os.getpid())
  35. if key not in instances:
  36. instances[key] = cls(*args, **kw)
  37. return instances[key]
  38. return _singleton
  39. CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
  40. AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {"create", "start", "end", "update", "read_access", "write_access"}
class TextFieldType(Enum):
    """Column type name used for long text, keyed by database backend name."""

    MYSQL = "LONGTEXT"
    POSTGRES = "TEXT"
class LongTextField(TextField):
    """Text field whose column type follows ``settings.DATABASE_TYPE``."""

    # Resolved once, when the class body executes; an unknown database type
    # raises KeyError from the enum lookup.
    field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
  46. class JSONField(LongTextField):
  47. default_value = {}
  48. def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
  49. self._object_hook = object_hook
  50. self._object_pairs_hook = object_pairs_hook
  51. super().__init__(**kwargs)
  52. def db_value(self, value):
  53. if value is None:
  54. value = self.default_value
  55. return utils.json_dumps(value)
  56. def python_value(self, value):
  57. if not value:
  58. return self.default_value
  59. return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
    """JSONField variant whose fallback for empty values is a list."""

    default_value = []
  62. class SerializedField(LongTextField):
  63. def __init__(self, serialized_type=SerializedType.PICKLE, object_hook=None, object_pairs_hook=None, **kwargs):
  64. self._serialized_type = serialized_type
  65. self._object_hook = object_hook
  66. self._object_pairs_hook = object_pairs_hook
  67. super().__init__(**kwargs)
  68. def db_value(self, value):
  69. if self._serialized_type == SerializedType.PICKLE:
  70. return utils.serialize_b64(value, to_str=True)
  71. elif self._serialized_type == SerializedType.JSON:
  72. if value is None:
  73. return None
  74. return utils.json_dumps(value, with_type=True)
  75. else:
  76. raise ValueError(f"the serialized type {self._serialized_type} is not supported")
  77. def python_value(self, value):
  78. if self._serialized_type == SerializedType.PICKLE:
  79. return utils.deserialize_b64(value)
  80. elif self._serialized_type == SerializedType.JSON:
  81. if value is None:
  82. return {}
  83. return utils.json_loads(value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  84. else:
  85. raise ValueError(f"the serialized type {self._serialized_type} is not supported")
  86. def is_continuous_field(cls: typing.Type) -> bool:
  87. if cls in CONTINUOUS_FIELD_TYPE:
  88. return True
  89. for p in cls.__bases__:
  90. if p in CONTINUOUS_FIELD_TYPE:
  91. return True
  92. elif p is not Field and p is not object:
  93. if is_continuous_field(p):
  94. return True
  95. else:
  96. return False
  97. def auto_date_timestamp_field():
  98. return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  99. def auto_date_timestamp_db_field():
  100. return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  101. def remove_field_name_prefix(field_name):
  102. return field_name[2:] if field_name.startswith("f_") else field_name
  103. class BaseModel(Model):
  104. create_time = BigIntegerField(null=True, index=True)
  105. create_date = DateTimeField(null=True, index=True)
  106. update_time = BigIntegerField(null=True, index=True)
  107. update_date = DateTimeField(null=True, index=True)
  108. def to_json(self):
  109. # This function is obsolete
  110. return self.to_dict()
  111. def to_dict(self):
  112. return self.__dict__["__data__"]
  113. def to_human_model_dict(self, only_primary_with: list = None):
  114. model_dict = self.__dict__["__data__"]
  115. if not only_primary_with:
  116. return {remove_field_name_prefix(k): v for k, v in model_dict.items()}
  117. human_model_dict = {}
  118. for k in self._meta.primary_key.field_names:
  119. human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
  120. for k in only_primary_with:
  121. human_model_dict[k] = model_dict[f"f_{k}"]
  122. return human_model_dict
  123. @property
  124. def meta(self) -> Metadata:
  125. return self._meta
  126. @classmethod
  127. def get_primary_keys_name(cls):
  128. return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [cls._meta.primary_key.name]
  129. @classmethod
  130. def getter_by(cls, attr):
  131. return operator.attrgetter(attr)(cls)
  132. @classmethod
  133. def query(cls, reverse=None, order_by=None, **kwargs):
  134. filters = []
  135. for f_n, f_v in kwargs.items():
  136. attr_name = "%s" % f_n
  137. if not hasattr(cls, attr_name) or f_v is None:
  138. continue
  139. if type(f_v) in {list, set}:
  140. f_v = list(f_v)
  141. if is_continuous_field(type(getattr(cls, attr_name))):
  142. if len(f_v) == 2:
  143. for i, v in enumerate(f_v):
  144. if isinstance(v, str) and f_n in auto_date_timestamp_field():
  145. # time type: %Y-%m-%d %H:%M:%S
  146. f_v[i] = utils.date_string_to_timestamp(v)
  147. lt_value = f_v[0]
  148. gt_value = f_v[1]
  149. if lt_value is not None and gt_value is not None:
  150. filters.append(cls.getter_by(attr_name).between(lt_value, gt_value))
  151. elif lt_value is not None:
  152. filters.append(operator.attrgetter(attr_name)(cls) >= lt_value)
  153. elif gt_value is not None:
  154. filters.append(operator.attrgetter(attr_name)(cls) <= gt_value)
  155. else:
  156. filters.append(operator.attrgetter(attr_name)(cls) << f_v)
  157. else:
  158. filters.append(operator.attrgetter(attr_name)(cls) == f_v)
  159. if filters:
  160. query_records = cls.select().where(*filters)
  161. if reverse is not None:
  162. if not order_by or not hasattr(cls, f"{order_by}"):
  163. order_by = "create_time"
  164. if reverse is True:
  165. query_records = query_records.order_by(cls.getter_by(f"{order_by}").desc())
  166. elif reverse is False:
  167. query_records = query_records.order_by(cls.getter_by(f"{order_by}").asc())
  168. return [query_record for query_record in query_records]
  169. else:
  170. return []
  171. @classmethod
  172. def insert(cls, __data=None, **insert):
  173. if isinstance(__data, dict) and __data:
  174. __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
  175. if insert:
  176. insert["create_time"] = utils.current_timestamp()
  177. return super().insert(__data, **insert)
  178. # update and insert will call this method
  179. @classmethod
  180. def _normalize_data(cls, data, kwargs):
  181. normalized = super()._normalize_data(data, kwargs)
  182. if not normalized:
  183. return {}
  184. normalized[cls._meta.combined["update_time"]] = utils.current_timestamp()
  185. for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
  186. if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and cls._meta.combined[f"{f_n}_time"] in normalized and normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
  187. normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(normalized[cls._meta.combined[f"{f_n}_time"]])
  188. return normalized
  189. class JsonSerializedField(SerializedField):
  190. def __init__(self, object_hook=utils.from_dict_hook, object_pairs_hook=None, **kwargs):
  191. super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook, object_pairs_hook=object_pairs_hook, **kwargs)
class PooledDatabase(Enum):
    """Pooled peewee database class to use for each backend name."""

    MYSQL = PooledMySQLDatabase
    POSTGRES = PooledPostgresqlDatabase
class DatabaseMigrator(Enum):
    """Schema migrator class to use for each backend name."""

    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator
  198. @singleton
  199. class BaseDataBase:
  200. def __init__(self):
  201. database_config = settings.DATABASE.copy()
  202. db_name = database_config.pop("name")
  203. self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
  204. logging.info("init database on cluster mode successfully")
  205. class PostgresDatabaseLock:
  206. def __init__(self, lock_name, timeout=10, db=None):
  207. self.lock_name = lock_name
  208. self.timeout = int(timeout)
  209. self.db = db if db else DB
  210. def lock(self):
  211. cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", self.timeout)
  212. ret = cursor.fetchone()
  213. if ret[0] == 0:
  214. raise Exception(f"acquire postgres lock {self.lock_name} timeout")
  215. elif ret[0] == 1:
  216. return True
  217. else:
  218. raise Exception(f"failed to acquire lock {self.lock_name}")
  219. def unlock(self):
  220. cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", self.timeout)
  221. ret = cursor.fetchone()
  222. if ret[0] == 0:
  223. raise Exception(f"postgres lock {self.lock_name} was not established by this thread")
  224. elif ret[0] == 1:
  225. return True
  226. else:
  227. raise Exception(f"postgres lock {self.lock_name} does not exist")
  228. def __enter__(self):
  229. if isinstance(self.db, PostgresDatabaseLock):
  230. self.lock()
  231. return self
  232. def __exit__(self, exc_type, exc_val, exc_tb):
  233. if isinstance(self.db, PostgresDatabaseLock):
  234. self.unlock()
  235. def __call__(self, func):
  236. @wraps(func)
  237. def magic(*args, **kwargs):
  238. with self:
  239. return func(*args, **kwargs)
  240. return magic
  241. class MysqlDatabaseLock:
  242. def __init__(self, lock_name, timeout=10, db=None):
  243. self.lock_name = lock_name
  244. self.timeout = int(timeout)
  245. self.db = db if db else DB
  246. def lock(self):
  247. # SQL parameters only support %s format placeholders
  248. cursor = self.db.execute_sql("SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
  249. ret = cursor.fetchone()
  250. if ret[0] == 0:
  251. raise Exception(f"acquire mysql lock {self.lock_name} timeout")
  252. elif ret[0] == 1:
  253. return True
  254. else:
  255. raise Exception(f"failed to acquire lock {self.lock_name}")
  256. def unlock(self):
  257. cursor = self.db.execute_sql("SELECT RELEASE_LOCK(%s)", (self.lock_name,))
  258. ret = cursor.fetchone()
  259. if ret[0] == 0:
  260. raise Exception(f"mysql lock {self.lock_name} was not established by this thread")
  261. elif ret[0] == 1:
  262. return True
  263. else:
  264. raise Exception(f"mysql lock {self.lock_name} does not exist")
  265. def __enter__(self):
  266. if isinstance(self.db, PooledMySQLDatabase):
  267. self.lock()
  268. return self
  269. def __exit__(self, exc_type, exc_val, exc_tb):
  270. if isinstance(self.db, PooledMySQLDatabase):
  271. self.unlock()
  272. def __call__(self, func):
  273. @wraps(func)
  274. def magic(*args, **kwargs):
  275. with self:
  276. return func(*args, **kwargs)
  277. return magic
class DatabaseLock(Enum):
    """Lock implementation to use for each backend name."""

    MYSQL = MysqlDatabaseLock
    POSTGRES = PostgresDatabaseLock
# Module-wide database handle: the pooled connection owned by the per-process
# BaseDataBase instance.
DB = BaseDataBase().database_connection
# Attach the lock class matching settings.DATABASE_TYPE, so a lock can be
# created as DB.lock(lock_name, timeout).
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
  283. def close_connection():
  284. try:
  285. if DB:
  286. DB.close_stale(age=30)
  287. except Exception as e:
  288. logging.exception(e)
class DataBaseModel(BaseModel):
    """BaseModel bound to the module-wide ``DB``; parent of every table model in this file."""

    class Meta:
        database = DB
  292. @DB.connection_context()
  293. def init_database_tables(alter_fields=[]):
  294. members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
  295. table_objs = []
  296. create_failed_list = []
  297. for name, obj in members:
  298. if obj != DataBaseModel and issubclass(obj, DataBaseModel):
  299. table_objs.append(obj)
  300. logging.debug(f"start create table {obj.__name__}")
  301. try:
  302. obj.create_table()
  303. logging.debug(f"create table success: {obj.__name__}")
  304. except Exception as e:
  305. logging.exception(e)
  306. create_failed_list.append(obj.__name__)
  307. if create_failed_list:
  308. logging.error(f"create tables failed: {create_failed_list}")
  309. raise Exception(f"create tables failed: {create_failed_list}")
  310. migrate_db()
  311. def fill_db_model_object(model_object, human_model_dict):
  312. for k, v in human_model_dict.items():
  313. attr_name = "%s" % k
  314. if hasattr(model_object.__class__, attr_name):
  315. setattr(model_object, attr_name, v)
  316. return model_object
class User(DataBaseModel, UserMixin):
    """User account stored in the ``user`` table; also mixes in flask-login's ``UserMixin``."""

    id = CharField(max_length=32, primary_key=True)
    access_token = CharField(max_length=255, null=True, index=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
    password = CharField(max_length=255, null=True, help_text="password", index=True)
    email = CharField(max_length=255, null=False, help_text="email", index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    # The default language is chosen once, at import time, from the LANG environment variable.
    language = CharField(max_length=32, null=True, help_text="English|Chinese", default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", index=True)
    color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright", index=True)
    timezone = CharField(max_length=64, null=True, help_text="Timezone", default="UTC+8\tAsia/Shanghai", index=True)
    last_login_time = DateTimeField(null=True, index=True)
    # NOTE(review): the next three columns hold "1"/"0" strings and presumably
    # override the UserMixin attributes of the same names - confirm how callers
    # test them, since the string "0" is truthy.
    is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
    is_active = CharField(max_length=1, null=False, default="1", index=True)
    is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
    login_channel = CharField(null=True, help_text="from which user login", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)
    is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)

    def __str__(self):
        return self.email

    def get_id(self):
        """Return the access token serialized with a URLSafeTimedSerializer keyed by ``settings.SECRET_KEY``."""
        jwt = Serializer(secret_key=settings.SECRET_KEY)
        return jwt.dumps(str(self.access_token))

    class Meta:
        db_table = "user"
class Tenant(DataBaseModel):
    """Tenant record (``tenant`` table) holding the tenant's default model IDs and credit."""

    id = CharField(max_length=32, primary_key=True)
    name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
    public_key = CharField(max_length=255, null=True, index=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
    asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID", index=True)
    img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID", index=True)
    rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID", index=True)
    tts_id = CharField(max_length=256, null=True, help_text="default tts model ID", index=True)
    parser_ids = CharField(max_length=256, null=False, help_text="document processors", index=True)
    credit = IntegerField(default=512, index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "tenant"
class UserTenant(DataBaseModel):
    """Link between a user and a tenant, with a role (``user_tenant`` table)."""

    id = CharField(max_length=32, primary_key=True)
    user_id = CharField(max_length=32, null=False, index=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
    invited_by = CharField(max_length=32, null=False, index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "user_tenant"
class InvitationCode(DataBaseModel):
    """Invitation code record (``invitation_code`` table)."""

    id = CharField(max_length=32, primary_key=True)
    code = CharField(max_length=32, null=False, index=True)
    visit_time = DateTimeField(null=True, index=True)
    user_id = CharField(max_length=32, null=True, index=True)
    tenant_id = CharField(max_length=32, null=True, index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "invitation_code"
class LLMFactories(DataBaseModel):
    """LLM factory record keyed by its name (``llm_factories`` table)."""

    name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "llm_factories"
class LLM(DataBaseModel):
    """Model catalogue entry (``llm`` table), keyed by factory id and model name."""

    # LLMs dictionary
    llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
    model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
    max_tokens = IntegerField(default=0)
    tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.llm_name

    class Meta:
        # Composite primary key: one row per (factory, model name).
        primary_key = CompositeKey("fid", "llm_name")
        db_table = "llm"
class TenantLLM(DataBaseModel):
    """Per-tenant model configuration and token usage (``tenant_llm`` table)."""

    tenant_id = CharField(max_length=32, null=False, index=True)
    llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
    model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
    llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
    api_key = CharField(max_length=2048, null=True, help_text="API KEY", index=True)
    api_base = CharField(max_length=255, null=True, help_text="API Base")
    max_tokens = IntegerField(default=8192, index=True)
    used_tokens = IntegerField(default=0, index=True)

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "tenant_llm"
        # Composite primary key: one row per (tenant, factory, model name).
        primary_key = CompositeKey("tenant_id", "llm_factory", "llm_name")
class TenantLangfuse(DataBaseModel):
    """Per-tenant Langfuse credentials and host (``tenant_langfuse`` table), keyed by tenant id."""

    tenant_id = CharField(max_length=32, null=False, primary_key=True)
    secret_key = CharField(max_length=2048, null=False, help_text="SECRET KEY", index=True)
    public_key = CharField(max_length=2048, null=False, help_text="PUBLIC KEY", index=True)
    host = CharField(max_length=128, null=False, help_text="host", index=True)
    # Disabled columns, kept for reference:
    # max_tokens = IntegerField(default=8192, index=True)
    # used_tokens = IntegerField(default=0, index=True)

    def __str__(self):
        return "Langfuse host" + self.host

    class Meta:
        db_table = "tenant_langfuse"
class Knowledgebase(DataBaseModel):
    """Knowledge base record (``knowledgebase`` table) with parser settings and counters."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=128, null=False, help_text="KB name", index=True)
    # The default language is chosen once, at import time, from the LANG environment variable.
    language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
    description = TextField(null=True, help_text="KB description")
    embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
    permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
    created_by = CharField(max_length=32, null=False, index=True)
    doc_num = IntegerField(default=0, index=True)
    token_num = IntegerField(default=0, index=True)
    chunk_num = IntegerField(default=0, index=True)
    similarity_threshold = FloatField(default=0.2, index=True)
    vector_similarity_weight = FloatField(default=0.3, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.NAIVE.value, index=True)
    # NOTE(review): the default is one dict object created when the class is
    # defined - confirm it is never mutated in place.
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    pagerank = IntegerField(default=0, index=False)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "knowledgebase"
class Document(DataBaseModel):
    """Document record (``document`` table) with parser settings and processing progress."""

    id = CharField(max_length=32, primary_key=True)
    thumbnail = TextField(null=True, help_text="thumbnail base64 string")
    kb_id = CharField(max_length=256, null=False, index=True)
    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
    # NOTE(review): the default is one dict object created when the class is
    # defined - confirm it is never mutated in place.
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
    type = CharField(max_length=32, null=False, help_text="file extension", index=True)
    created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
    name = CharField(max_length=255, null=True, help_text="file name", index=True)
    location = CharField(max_length=255, null=True, help_text="where dose it store", index=True)
    size = IntegerField(default=0, index=True)
    token_num = IntegerField(default=0, index=True)
    chunk_num = IntegerField(default=0, index=True)
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(null=True, help_text="process message", default="")
    process_begin_at = DateTimeField(null=True, index=True)
    # The attribute name is spelled "duation"; it is the column name, so it must stay as is.
    process_duation = FloatField(default=0)
    meta_fields = JSONField(null=True, default={})
    run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0", index=True)
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "document"
class File(DataBaseModel):
    """File or folder entry (``file`` table), linked to its parent folder by ``parent_id``."""

    id = CharField(max_length=32, primary_key=True)
    parent_id = CharField(max_length=32, null=False, help_text="parent folder id", index=True)
    tenant_id = CharField(max_length=32, null=False, help_text="tenant id", index=True)
    created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
    name = CharField(max_length=255, null=False, help_text="file name or folder name", index=True)
    location = CharField(max_length=255, null=True, help_text="where dose it store", index=True)
    size = IntegerField(default=0, index=True)
    type = CharField(max_length=32, null=False, help_text="file extension", index=True)
    source_type = CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)

    class Meta:
        db_table = "file"
class File2Document(DataBaseModel):
    """Mapping between a file id and a document id (``file2document`` table)."""

    id = CharField(max_length=32, primary_key=True)
    file_id = CharField(max_length=32, null=True, help_text="file id", index=True)
    document_id = CharField(max_length=32, null=True, help_text="document id", index=True)

    class Meta:
        db_table = "file2document"
class Task(DataBaseModel):
    """Processing task for a page range of a document.

    No ``Meta.db_table`` is declared, so the table name is left to peewee's
    default naming.
    """

    id = CharField(max_length=32, primary_key=True)
    doc_id = CharField(max_length=32, null=False, index=True)
    from_page = IntegerField(default=0)
    to_page = IntegerField(default=100000000)
    task_type = CharField(max_length=32, null=False, default="")
    priority = IntegerField(default=0)
    begin_at = DateTimeField(null=True, index=True)
    # The attribute name is spelled "duation"; it is the column name, so it must stay as is.
    process_duation = FloatField(default=0)
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(null=True, help_text="process message", default="")
    retry_count = IntegerField(default=0)
    digest = TextField(null=True, help_text="task digest", default="")
    chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")
class Dialog(DataBaseModel):
    """Dialog application record (``dialog`` table): LLM, prompt and retrieval settings."""

    id = CharField(max_length=32, primary_key=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=255, null=True, help_text="dialog application name", index=True)
    description = TextField(null=True, help_text="Dialog description")
    icon = TextField(null=True, help_text="icon base64 string")
    # The default language is chosen once, at import time, from the LANG environment variable.
    language = CharField(max_length=32, null=True, default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English", help_text="English|Chinese", index=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    # NOTE(review): the dict/list defaults below are single objects created
    # when the class is defined - confirm they are never mutated in place.
    llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7, "presence_penalty": 0.4, "max_tokens": 512})
    prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced", index=True)
    prompt_config = JSONField(
        null=False,
        default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?", "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"},
    )
    similarity_threshold = FloatField(default=0.2)
    vector_similarity_weight = FloatField(default=0.3)
    top_n = IntegerField(default=6)
    top_k = IntegerField(default=1024)
    do_refer = CharField(max_length=1, null=False, default="1", help_text="it needs to insert reference index into answer or not")
    rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID")
    kb_ids = JSONField(null=False, default=[])
    status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)

    class Meta:
        db_table = "dialog"
class Conversation(DataBaseModel):
    """Conversation belonging to a dialog (``conversation`` table), with JSON message and reference data."""

    id = CharField(max_length=32, primary_key=True)
    dialog_id = CharField(max_length=32, null=False, index=True)
    name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
    message = JSONField(null=True)
    reference = JSONField(null=True, default=[])
    user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)

    class Meta:
        db_table = "conversation"
class APIToken(DataBaseModel):
    """API token issued to a tenant (``api_token`` table), keyed by tenant id and token."""

    tenant_id = CharField(max_length=32, null=False, index=True)
    token = CharField(max_length=255, null=False, index=True)
    dialog_id = CharField(max_length=32, null=True, index=True)
    source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
    beta = CharField(max_length=255, null=True, index=True)

    class Meta:
        db_table = "api_token"
        # Composite primary key: one row per (tenant, token).
        primary_key = CompositeKey("tenant_id", "token")
class API4Conversation(DataBaseModel):
    """Conversation created through the API (``api_4_conversation`` table), with usage counters."""

    id = CharField(max_length=32, primary_key=True)
    dialog_id = CharField(max_length=32, null=False, index=True)
    user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
    message = JSONField(null=True)
    reference = JSONField(null=True, default=[])
    tokens = IntegerField(default=0)
    source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
    dsl = JSONField(null=True, default={})
    duration = FloatField(default=0, index=True)
    round = IntegerField(default=0, index=True)
    thumb_up = IntegerField(default=0, index=True)

    class Meta:
        db_table = "api_4_conversation"
class UserCanvas(DataBaseModel):
    """User-owned canvas and its JSON ``dsl`` (``user_canvas`` table)."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
    title = CharField(max_length=255, null=True, help_text="Canvas title")
    permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
    description = TextField(null=True, help_text="Canvas description")
    canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
    dsl = JSONField(null=True, default={})

    class Meta:
        db_table = "user_canvas"
class CanvasTemplate(DataBaseModel):
    """Canvas template and its JSON ``dsl`` (``canvas_template`` table)."""

    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    title = CharField(max_length=255, null=True, help_text="Canvas title")
    description = TextField(null=True, help_text="Canvas description")
    canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
    dsl = JSONField(null=True, default={})

    class Meta:
        db_table = "canvas_template"
class UserCanvasVersion(DataBaseModel):
    """Stored version of a user canvas (``user_canvas_version`` table), linked by ``user_canvas_id``."""

    id = CharField(max_length=32, primary_key=True)
    user_canvas_id = CharField(max_length=255, null=False, help_text="user_canvas_id", index=True)
    title = CharField(max_length=255, null=True, help_text="Canvas title")
    description = TextField(null=True, help_text="Canvas description")
    dsl = JSONField(null=True, default={})

    class Meta:
        db_table = "user_canvas_version"
  583. def migrate_db():
  584. with DB.transaction():
  585. migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
  586. try:
  587. migrate(migrator.add_column("file", "source_type", CharField(max_length=128, null=False, default="", help_text="where dose this document come from", index=True)))
  588. except Exception:
  589. pass
  590. try:
  591. migrate(migrator.add_column("tenant", "rerank_id", CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3", help_text="default rerank model ID")))
  592. except Exception:
  593. pass
  594. try:
  595. migrate(migrator.add_column("dialog", "rerank_id", CharField(max_length=128, null=False, default="", help_text="default rerank model ID")))
  596. except Exception:
  597. pass
  598. try:
  599. migrate(migrator.add_column("dialog", "top_k", IntegerField(default=1024)))
  600. except Exception:
  601. pass
  602. try:
  603. migrate(migrator.alter_column_type("tenant_llm", "api_key", CharField(max_length=2048, null=True, help_text="API KEY", index=True)))
  604. except Exception:
  605. pass
  606. try:
  607. migrate(migrator.add_column("api_token", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
  608. except Exception:
  609. pass
  610. try:
  611. migrate(migrator.add_column("tenant", "tts_id", CharField(max_length=256, null=True, help_text="default tts model ID", index=True)))
  612. except Exception:
  613. pass
  614. try:
  615. migrate(migrator.add_column("api_4_conversation", "source", CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)))
  616. except Exception:
  617. pass
  618. try:
  619. migrate(migrator.add_column("task", "retry_count", IntegerField(default=0)))
  620. except Exception:
  621. pass
  622. try:
  623. migrate(migrator.alter_column_type("api_token", "dialog_id", CharField(max_length=32, null=True, index=True)))
  624. except Exception:
  625. pass
  626. try:
  627. migrate(migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True)))
  628. except Exception:
  629. pass
  630. try:
  631. migrate(migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={})))
  632. except Exception:
  633. pass
  634. try:
  635. migrate(migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False)))
  636. except Exception:
  637. pass
  638. try:
  639. migrate(migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True)))
  640. except Exception:
  641. pass
  642. try:
  643. migrate(migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default="")))
  644. except Exception:
  645. pass
  646. try:
  647. migrate(migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default="")))
  648. except Exception:
  649. pass
  650. try:
  651. migrate(migrator.add_column("conversation", "user_id", CharField(max_length=255, null=True, help_text="user_id", index=True)))
  652. except Exception:
  653. pass
  654. try:
  655. migrate(migrator.add_column("document", "meta_fields", JSONField(null=True, default={})))
  656. except Exception:
  657. pass
  658. try:
  659. migrate(migrator.add_column("task", "task_type", CharField(max_length=32, null=False, default="")))
  660. except Exception:
  661. pass
  662. try:
  663. migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
  664. except Exception:
  665. pass
  666. try:
  667. migrate(migrator.add_column("user_canvas", "permission", CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)))
  668. except Exception:
  669. pass