Du kan inte välja fler än 25 ämnen Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

db_models.py 34KB


  #
  # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  # http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  #
  import copy
  import hashlib
  import inspect
  import operator
  import os
  import sys
  import typing
  from enum import Enum
  from functools import wraps

  from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
  from flask_login import UserMixin
  from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
  from peewee import (
      BigIntegerField, BooleanField, CharField,
      CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
      Field, Model, Metadata
  )
  from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase

  from api.db import SerializedType, ParserType
  from api.settings import DATABASE, stat_logger, SECRET_KEY, DATABASE_TYPE
  from api.utils.log_utils import getLogger
  from api import utils
  # Module-level logger shared by the table-creation and connection helpers in this file.
  LOGGER = getLogger()
  def singleton(cls, *args, **kw):
      """Wrap ``cls`` in a zero-argument factory that builds it at most once per process.

      The instance is created lazily, on the first call of the returned factory,
      with the ``args``/``kw`` captured here. The cache key includes the current
      PID, so a process with a different PID gets its own instance.
      """
      cache = {}

      def _singleton():
          cache_key = str(cls) + str(os.getpid())
          if cache_key in cache:
              return cache[cache_key]
          cache[cache_key] = cls(*args, **kw)
          return cache[cache_key]

      return _singleton
  # Field classes that is_continuous_field() reports as continuous; BaseModel.query
  # turns a two-element list given for such a field into a range filter.
  CONTINUOUS_FIELD_TYPE = {
      IntegerField,
      FloatField,
      DateTimeField,
  }
  # Name prefixes of paired ``<prefix>_time`` / ``<prefix>_date`` columns.
  # BaseModel._normalize_data fills the ``_date`` column from the ``_time`` one
  # whenever a model declares both.
  AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
      "create",
      "start",
      "end",
      "update",
      "read_access",
      "write_access",
  }
  class TextFieldType(Enum):
      """SQL column type used for long text, looked up by upper-cased backend name."""

      MYSQL = 'LONGTEXT'
      POSTGRES = 'TEXT'
  class LongTextField(TextField):
      """Text field whose column type follows the configured database backend."""

      # Resolved once, at import time, from DATABASE_TYPE via TextFieldType.
      field_type = TextFieldType[DATABASE_TYPE.upper()].value
  class JSONField(LongTextField):
      """Long-text column that stores its Python value as a JSON document.

      ``None`` is written as the JSON form of ``default_value``. An empty or
      NULL column is read back as a *copy* of ``default_value``, so callers may
      mutate the result without corrupting the class-level default shared by
      every field instance (and by subclasses such as ``ListField``).

      Args:
          object_hook: forwarded to ``utils.json_loads`` when reading.
          object_pairs_hook: forwarded to ``utils.json_loads`` when reading.
          **kwargs: passed through to the parent field constructor.
      """

      default_value = {}

      def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
          self._object_hook = object_hook
          self._object_pairs_hook = object_pairs_hook
          super().__init__(**kwargs)

      def db_value(self, value):
          """Serialize ``value`` (or the default when it is ``None``) to a JSON string."""
          if value is None:
              value = self.default_value
          return utils.json_dumps(value)

      def python_value(self, value):
          """Parse the stored JSON; a falsy column value yields a fresh default."""
          if not value:
              # Never hand out the shared class attribute itself.
              return copy.deepcopy(self.default_value)
          return utils.json_loads(
              value,
              object_hook=self._object_hook,
              object_pairs_hook=self._object_pairs_hook)
  class ListField(JSONField):
      """JSON field whose default for ``None`` / empty column values is a list, not a dict."""

      default_value = []
  class SerializedField(LongTextField):
      """Long-text column holding a value serialized as base64 pickle or typed JSON.

      Args:
          serialized_type: ``SerializedType.PICKLE`` (default) or ``SerializedType.JSON``.
          object_hook: forwarded to ``utils.json_loads`` in JSON mode.
          object_pairs_hook: forwarded to ``utils.json_loads`` in JSON mode.
          **kwargs: passed through to the parent field constructor.

      Raises:
          ValueError: from ``db_value`` / ``python_value`` for any other serialized type.
      """

      def __init__(self, serialized_type=SerializedType.PICKLE,
                   object_hook=None, object_pairs_hook=None, **kwargs):
          self._serialized_type = serialized_type
          self._object_hook = object_hook
          self._object_pairs_hook = object_pairs_hook
          super().__init__(**kwargs)

      def db_value(self, value):
          """Serialize ``value`` for storage according to the configured type."""
          kind = self._serialized_type
          if kind == SerializedType.PICKLE:
              return utils.serialize_b64(value, to_str=True)
          if kind == SerializedType.JSON:
              # In JSON mode ``None`` is stored as NULL rather than as a JSON document.
              return None if value is None else utils.json_dumps(value, with_type=True)
          raise ValueError(
              f"the serialized type {self._serialized_type} is not supported")

      def python_value(self, value):
          """Deserialize a stored column value according to the configured type."""
          kind = self._serialized_type
          if kind == SerializedType.PICKLE:
              # NOTE(review): presumably this unpickles the column content -
              # confirm that only trusted data is ever stored in PICKLE mode.
              return utils.deserialize_b64(value)
          if kind == SerializedType.JSON:
              if value is None:
                  return {}
              return utils.json_loads(
                  value,
                  object_hook=self._object_hook,
                  object_pairs_hook=self._object_pairs_hook)
          raise ValueError(
              f"the serialized type {self._serialized_type} is not supported")
  def is_continuous_field(cls: typing.Type) -> bool:
      """Return True when ``cls`` or one of its ancestors is in CONTINUOUS_FIELD_TYPE.

      Base classes are searched recursively; ``Field`` and ``object`` are not
      descended into.
      """
      if cls in CONTINUOUS_FIELD_TYPE:
          return True
      for base in cls.__bases__:
          if base in CONTINUOUS_FIELD_TYPE:
              return True
          if base == Field or base == object:
              continue
          if is_continuous_field(base):
              return True
      return False
  def auto_date_timestamp_field():
      """Return the set of ``<prefix>_time`` names for every auto-timestamp prefix."""
      return set(map("{}_time".format, AUTO_DATE_TIMESTAMP_FIELD_PREFIX))
  def auto_date_timestamp_db_field():
      """Return the set of ``f_<prefix>_time`` names for every auto-timestamp prefix."""
      return set(map("f_{}_time".format, AUTO_DATE_TIMESTAMP_FIELD_PREFIX))
  def remove_field_name_prefix(field_name):
      """Drop a leading ``f_`` from ``field_name``; any other name is returned unchanged."""
      if field_name.startswith('f_'):
          return field_name[2:]
      return field_name
  class BaseModel(Model):
      """Common base for the tables in this file.

      Declares the create/update bookkeeping columns and provides dict
      conversion, a keyword-based ``query`` helper, and hooks that stamp
      ``create_time`` / ``update_time`` on insert and update.
      """

      create_time = BigIntegerField(null=True, index=True)
      create_date = DateTimeField(null=True, index=True)
      update_time = BigIntegerField(null=True, index=True)
      update_date = DateTimeField(null=True, index=True)

      def to_json(self):
          # This function is obsolete
          return self.to_dict()

      def to_dict(self):
          """Return the instance's field-data dict (the live object, not a copy)."""
          return self.__dict__['__data__']

      def to_human_model_dict(self, only_primary_with: list = None):
          """Return the field data keyed by names without the ``f_`` prefix.

          Args:
              only_primary_with: when given and non-empty, the result contains
                  only the primary-key fields plus these (un-prefixed) names.

          Raises:
              KeyError: if a requested field has no value on this instance.
          """
          model_dict = self.__dict__['__data__']

          if not only_primary_with:
              return {remove_field_name_prefix(k): v for k, v in model_dict.items()}

          human_model_dict = {}
          # get_primary_keys_name() handles single and composite keys alike;
          # reading ``primary_key.field_names`` directly does not.
          for k in self.get_primary_keys_name():
              human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
          for k in only_primary_with:
              # Accept both prefixed and un-prefixed storage names, mirroring
              # remove_field_name_prefix().
              prefixed = f'f_{k}'
              human_model_dict[k] = model_dict[prefixed] if prefixed in model_dict else model_dict[k]
          return human_model_dict

      @property
      def meta(self) -> Metadata:
          return self._meta

      @classmethod
      def get_primary_keys_name(cls):
          """Return the primary-key field name(s) as a sequence, composite or not."""
          primary_key = cls._meta.primary_key
          if isinstance(primary_key, CompositeKey):
              return primary_key.field_names
          return [primary_key.name]

      @classmethod
      def getter_by(cls, attr):
          """Return the class attribute (field) named ``attr``."""
          return operator.attrgetter(attr)(cls)

      @classmethod
      def query(cls, reverse=None, order_by=None, **kwargs):
          """Select rows matching the given ``field=value`` keyword filters.

          Keywords that are not attributes of the model, or whose value is
          ``None``, are ignored. A list/set value becomes an ``IN`` filter,
          except on continuous fields, where a two-element list is a
          ``[low, high]`` range (either bound may be ``None``); lists of any
          other length on a continuous field add no filter. String bounds on
          auto-timestamp fields are converted with
          ``utils.date_string_to_timestamp``.

          Args:
              reverse: ``True`` for descending, ``False`` for ascending,
                  ``None`` for no explicit ordering.
              order_by: attribute to order by; falls back to ``create_time``.

          Returns:
              A list of model instances; empty when no filter was produced.
          """
          filters = []
          for f_n, f_v in kwargs.items():
              attr_name = '%s' % f_n
              if not hasattr(cls, attr_name) or f_v is None:
                  continue
              column = cls.getter_by(attr_name)
              if type(f_v) not in {list, set}:
                  filters.append(column == f_v)
                  continue
              f_v = list(f_v)
              if not is_continuous_field(type(getattr(cls, attr_name))):
                  filters.append(column << f_v)
                  continue
              if len(f_v) != 2:
                  continue
              for i, v in enumerate(f_v):
                  if isinstance(v, str) and f_n in auto_date_timestamp_field():
                      # time type: %Y-%m-%d %H:%M:%S
                      f_v[i] = utils.date_string_to_timestamp(v)
              lower, upper = f_v
              if lower is not None and upper is not None:
                  filters.append(column.between(lower, upper))
              elif lower is not None:
                  filters.append(column >= lower)
              elif upper is not None:
                  filters.append(column <= upper)

          if not filters:
              return []

          query_records = cls.select().where(*filters)
          if reverse is not None:
              if not order_by or not hasattr(cls, f"{order_by}"):
                  order_by = "create_time"
              if reverse is True:
                  query_records = query_records.order_by(
                      cls.getter_by(f"{order_by}").desc())
              elif reverse is False:
                  query_records = query_records.order_by(
                      cls.getter_by(f"{order_by}").asc())
          return list(query_records)

      @classmethod
      def insert(cls, __data=None, **insert):
          """Build an insert query, stamping ``create_time`` on the supplied data.

          Note that a dict passed as the positional argument is modified in place.
          """
          if isinstance(__data, dict) and __data:
              __data[cls._meta.combined["create_time"]] = utils.current_timestamp()
          if insert:
              insert["create_time"] = utils.current_timestamp()
          return super().insert(__data, **insert)

      # update and insert will call this method
      @classmethod
      def _normalize_data(cls, data, kwargs):
          """Stamp ``update_time`` and derive each ``*_date`` from its ``*_time`` value."""
          normalized = super()._normalize_data(data, kwargs)
          if not normalized:
              return {}

          fields = cls._meta.combined
          normalized[fields["update_time"]] = utils.current_timestamp()

          for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
              time_name, date_name = f"{f_n}_time", f"{f_n}_date"
              if {time_name, date_name}.issubset(fields.keys()) and \
                      fields[time_name] in normalized and \
                      normalized[fields[time_name]] is not None:
                  normalized[fields[date_name]] = utils.timestamp_to_date(
                      normalized[fields[time_name]])
          return normalized
  class JsonSerializedField(SerializedField):
      """SerializedField fixed to JSON mode, decoding with ``utils.from_dict_hook`` by default."""

      def __init__(self, object_hook=utils.from_dict_hook,
                   object_pairs_hook=None, **kwargs):
          super().__init__(
              serialized_type=SerializedType.JSON,
              object_hook=object_hook,
              object_pairs_hook=object_pairs_hook,
              **kwargs)
  class PooledDatabase(Enum):
      """Pooled peewee database class per backend, looked up by upper-cased DATABASE_TYPE."""

      MYSQL = PooledMySQLDatabase
      POSTGRES = PooledPostgresqlDatabase
  class DatabaseMigrator(Enum):
      """Schema migrator class per backend, looked up by upper-cased DATABASE_TYPE."""

      MYSQL = MySQLMigrator
      POSTGRES = PostgresqlMigrator
  @singleton
  class BaseDataBase:
      """Holder of the pooled database connection built from the DATABASE settings."""

      def __init__(self):
          # Work on a copy so popping "name" leaves the shared settings dict intact.
          settings = DATABASE.copy()
          db_name = settings.pop("name")
          pool_class = PooledDatabase[DATABASE_TYPE.upper()].value
          # Every remaining setting is forwarded to the pool as a keyword argument.
          self.database_connection = pool_class(db_name, **settings)
          stat_logger.info('init database on cluster mode successfully')
  class PostgresDatabaseLock:
      """Named lock backed by PostgreSQL session-level advisory locks.

      Usable as a context manager or as a function decorator. Advisory locks
      are keyed by integers, so a stable integer ``lock_id`` is derived from
      ``lock_name``.

      Args:
          lock_name: logical name of the lock.
          timeout: kept for interface parity with ``MysqlDatabaseLock``; it is
              stored as an int but is not sent to PostgreSQL.
          db: database to run the lock statements on; defaults to the module ``DB``.
      """

      def __init__(self, lock_name, timeout=10, db=None):
          self.lock_name = lock_name
          # Same name -> same key in every process, so they contend for one lock.
          digest = hashlib.sha256(str(lock_name).encode("utf-8")).hexdigest()
          self.lock_id = int(digest, 16) % (2 ** 31 - 1)
          self.timeout = int(timeout)
          self.db = db if db else DB

      def lock(self):
          """Try to take the lock once.

          Returns:
              True when the lock was acquired.

          Raises:
              Exception: when the lock is held elsewhere or the reply is unexpected.
          """
          cursor = self.db.execute_sql(
              "SELECT pg_try_advisory_lock(%s)", (self.lock_id,))
          ret = cursor.fetchone()
          if ret[0] == 0:
              raise Exception(f'acquire postgres lock {self.lock_name} timeout')
          elif ret[0] == 1:
              return True
          else:
              raise Exception(f'failed to acquire lock {self.lock_name}')

      def unlock(self):
          """Release the lock.

          Returns:
              True when the lock was released.

          Raises:
              Exception: when this session did not hold the lock or the reply is unexpected.
          """
          cursor = self.db.execute_sql(
              "SELECT pg_advisory_unlock(%s)", (self.lock_id,))
          ret = cursor.fetchone()
          if ret[0] == 0:
              raise Exception(
                  f'postgres lock {self.lock_name} was not established by this thread')
          elif ret[0] == 1:
              return True
          else:
              raise Exception(f'postgres lock {self.lock_name} does not exist')

      def __enter__(self):
          # Only lock on a real PostgreSQL pool; any other db object makes this a no-op.
          if isinstance(self.db, PooledPostgresqlDatabase):
              self.lock()
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          if isinstance(self.db, PooledPostgresqlDatabase):
              self.unlock()

      def __call__(self, func):
          """Decorate ``func`` so that every call runs while holding the lock."""
          @wraps(func)
          def magic(*args, **kwargs):
              with self:
                  return func(*args, **kwargs)

          return magic
  class MysqlDatabaseLock:
      """Named lock built on MySQL ``GET_LOCK`` / ``RELEASE_LOCK``.

      Usable as a context manager or as a function decorator.

      Args:
          lock_name: name passed to ``GET_LOCK`` / ``RELEASE_LOCK``.
          timeout: value passed as the second ``GET_LOCK`` argument (coerced to int).
          db: database to run the lock statements on; defaults to the module ``DB``.
      """

      def __init__(self, lock_name, timeout=10, db=None):
          self.lock_name = lock_name
          self.timeout = int(timeout)
          self.db = db if db else DB

      def lock(self):
          """Acquire the lock; return True or raise ``Exception`` on any other reply."""
          # SQL parameters only support %s format placeholders
          cursor = self.db.execute_sql(
              "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
          status = cursor.fetchone()[0]
          if status == 0:
              raise Exception(f'acquire mysql lock {self.lock_name} timeout')
          if status == 1:
              return True
          raise Exception(f'failed to acquire lock {self.lock_name}')

      def unlock(self):
          """Release the lock; return True or raise ``Exception`` on any other reply."""
          cursor = self.db.execute_sql(
              "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
          status = cursor.fetchone()[0]
          if status == 0:
              raise Exception(
                  f'mysql lock {self.lock_name} was not established by this thread')
          if status == 1:
              return True
          raise Exception(f'mysql lock {self.lock_name} does not exist')

      def __enter__(self):
          # Only lock on a real MySQL pool; any other db object makes this a no-op.
          if isinstance(self.db, PooledMySQLDatabase):
              self.lock()
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          if isinstance(self.db, PooledMySQLDatabase):
              self.unlock()

      def __call__(self, func):
          """Decorate ``func`` so that every call runs inside ``with self``."""
          @wraps(func)
          def magic(*args, **kwargs):
              with self:
                  return func(*args, **kwargs)

          return magic
  class DatabaseLock(Enum):
      """Lock implementation per backend, looked up by upper-cased DATABASE_TYPE."""

      MYSQL = MysqlDatabaseLock
      POSTGRES = PostgresDatabaseLock
  # Pooled connection shared by every model in this module.
  DB = BaseDataBase().database_connection
  # Attach the backend-specific lock class to the connection object as ``DB.lock``.
  DB.lock = DatabaseLock[DATABASE_TYPE.upper()].value
  def close_connection():
      """Close stale pooled connections (``age=30``); errors are logged, never raised."""
      try:
          if not DB:
              return
          DB.close_stale(age=30)
      except Exception as exc:
          LOGGER.exception(exc)
  class DataBaseModel(BaseModel):
      """BaseModel bound to the module-level ``DB`` connection; parent of all table models."""

      class Meta:
          database = DB
  @DB.connection_context()
  def init_database_tables(alter_fields=None):
      """Create a table for every ``DataBaseModel`` subclass in this module, then migrate.

      Creation is attempted for every model even if an earlier one fails; the
      failures are collected and reported together.

      Args:
          alter_fields: accepted for backward compatibility; not used.

      Raises:
          Exception: listing the model names whose table could not be created.
              ``migrate_db`` is not run in that case.
      """
      create_failed_list = []
      for _, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass):
          if obj == DataBaseModel or not issubclass(obj, DataBaseModel):
              continue
          LOGGER.info(f"start create table {obj.__name__}")
          try:
              obj.create_table()
              LOGGER.info(f"create table success: {obj.__name__}")
          except Exception as e:
              LOGGER.exception(e)
              create_failed_list.append(obj.__name__)
      if create_failed_list:
          LOGGER.error(f"create tables failed: {create_failed_list}")
          raise Exception(f"create tables failed: {create_failed_list}")
      migrate_db()
  def fill_db_model_object(model_object, human_model_dict):
      """Copy entries of ``human_model_dict`` onto ``model_object`` and return it.

      Only keys that name an attribute of the object's class are applied; all
      other keys are silently skipped.
      """
      model_class = model_object.__class__
      for key, value in human_model_dict.items():
          attr_name = '%s' % key
          if not hasattr(model_class, attr_name):
              continue
          setattr(model_object, attr_name, value)
      return model_object
  class User(DataBaseModel, UserMixin):
      """Account row (table ``user``); also mixes in flask-login's ``UserMixin``."""

      id = CharField(max_length=32, primary_key=True)
      access_token = CharField(max_length=255, null=True, index=True)
      nickname = CharField(
          max_length=100,
          null=False,
          help_text="nicky name",
          index=True)
      password = CharField(
          max_length=255,
          null=True,
          help_text="password",
          index=True)
      email = CharField(max_length=255, null=False, help_text="email", index=True)
      avatar = TextField(null=True, help_text="avatar base64 string")
      # The default language is picked once, at import time, from the LANG environment variable.
      language = CharField(
          max_length=32,
          null=True,
          help_text="English|Chinese",
          default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
          index=True)
      color_schema = CharField(max_length=32, null=True, help_text="Bright|Dark", default="Bright", index=True)
      timezone = CharField(
          max_length=64,
          null=True,
          help_text="Timezone",
          default="UTC+8\tAsia/Shanghai",
          index=True)
      last_login_time = DateTimeField(null=True, index=True)
      # These three columns share their names with attributes flask-login reads on a user object.
      is_authenticated = CharField(
          max_length=1,
          null=False,
          default="1",
          index=True)
      is_active = CharField(
          max_length=1,
          null=False,
          default="1",
          index=True)
      is_anonymous = CharField(
          max_length=1,
          null=False,
          default="0",
          index=True)
      login_channel = CharField(
          null=True,
          help_text="from which user login",
          index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)
      is_superuser = BooleanField(
          null=True,
          help_text="is root",
          default=False,
          index=True)

      def __str__(self):
          return self.email

      def get_id(self):
          """Return the access token serialized (signed) with ``SECRET_KEY``."""
          return Serializer(secret_key=SECRET_KEY).dumps(str(self.access_token))

      class Meta:
          db_table = "user"
  class Tenant(DataBaseModel):
      """Tenant row (table ``tenant``) holding the tenant's default model IDs and credit."""

      id = CharField(max_length=32, primary_key=True)
      name = CharField(
          max_length=100,
          null=True,
          help_text="Tenant name",
          index=True)
      public_key = CharField(max_length=255, null=True, index=True)
      llm_id = CharField(
          max_length=128,
          null=False,
          help_text="default llm ID",
          index=True)
      embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
      asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID", index=True)
      img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID", index=True)
      rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID", index=True)
      tts_id = CharField(max_length=256, null=True, help_text="default tts model ID", index=True)
      parser_ids = CharField(max_length=256, null=False, help_text="document processors", index=True)
      credit = IntegerField(default=512, index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      class Meta:
          db_table = "tenant"
  class UserTenant(DataBaseModel):
      """Row of table ``user_tenant``: a user's role in a tenant and who invited them."""

      id = CharField(max_length=32, primary_key=True)
      user_id = CharField(max_length=32, null=False, index=True)
      tenant_id = CharField(max_length=32, null=False, index=True)
      role = CharField(
          max_length=32,
          null=False,
          help_text="UserTenantRole",
          index=True)
      invited_by = CharField(max_length=32, null=False, index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      class Meta:
          db_table = "user_tenant"
  class InvitationCode(DataBaseModel):
      """Row of table ``invitation_code``: a code with optional user and tenant references."""

      id = CharField(max_length=32, primary_key=True)
      code = CharField(max_length=32, null=False, index=True)
      visit_time = DateTimeField(null=True, index=True)
      user_id = CharField(max_length=32, null=True, index=True)
      tenant_id = CharField(max_length=32, null=True, index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      class Meta:
          db_table = "invitation_code"
  class LLMFactories(DataBaseModel):
      """LLM factory row (table ``llm_factories``), keyed by the factory name."""

      name = CharField(max_length=128, null=False, help_text="LLM factory name", primary_key=True)
      logo = TextField(null=True, help_text="llm logo base64")
      tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      def __str__(self):
          return self.name

      class Meta:
          db_table = "llm_factories"
  class LLM(DataBaseModel):
      """Model catalogue row (table ``llm``), keyed by ``(fid, llm_name)``."""

      # LLMs dictionary
      llm_name = CharField(max_length=128, null=False, help_text="LLM name", index=True)
      model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
      fid = CharField(
          max_length=128,
          null=False,
          help_text="LLM factory id",
          index=True)
      max_tokens = IntegerField(default=0)
      tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...", index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      def __str__(self):
          return self.llm_name

      class Meta:
          primary_key = CompositeKey('fid', 'llm_name')
          db_table = "llm"
  class TenantLLM(DataBaseModel):
      """Per-tenant model row (table ``tenant_llm``) with API key, API base and token usage.

      Keyed by ``(tenant_id, llm_factory, llm_name)``.
      """

      tenant_id = CharField(max_length=32, null=False, index=True)
      llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
      model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
      llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
      api_key = CharField(
          max_length=1024,
          null=True,
          help_text="API KEY",
          index=True)
      api_base = CharField(
          max_length=255,
          null=True,
          help_text="API Base")
      used_tokens = IntegerField(default=0, index=True)

      def __str__(self):
          return self.llm_name

      class Meta:
          db_table = "tenant_llm"
          primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
  class Knowledgebase(DataBaseModel):
      """Knowledge base row (table ``knowledgebase``): owner, embedding model, parser and counters."""

      id = CharField(max_length=32, primary_key=True)
      avatar = TextField(null=True, help_text="avatar base64 string")
      tenant_id = CharField(max_length=32, null=False, index=True)
      name = CharField(max_length=128, null=False, help_text="KB name", index=True)
      # The default language is picked once, at import time, from the LANG environment variable.
      language = CharField(
          max_length=32,
          null=True,
          default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
          help_text="English|Chinese",
          index=True)
      description = TextField(null=True, help_text="KB description")
      embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID", index=True)
      permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
      created_by = CharField(max_length=32, null=False, index=True)
      doc_num = IntegerField(default=0, index=True)
      token_num = IntegerField(default=0, index=True)
      chunk_num = IntegerField(default=0, index=True)

      similarity_threshold = FloatField(default=0.2, index=True)
      vector_similarity_weight = FloatField(default=0.3, index=True)

      parser_id = CharField(
          max_length=32,
          null=False,
          help_text="default parser ID",
          default=ParserType.NAIVE.value,
          index=True)
      # A callable default gives every new instance its own dict; a literal
      # dict would be one object shared (and mutated) across all instances.
      parser_config = JSONField(null=False, default=lambda: {"pages": [[1, 1000000]]})
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      def __str__(self):
          return self.name

      class Meta:
          db_table = "knowledgebase"
  class Document(DataBaseModel):
      """Document row (table ``document``): source file info, parser settings and processing progress."""

      id = CharField(max_length=32, primary_key=True)
      thumbnail = TextField(null=True, help_text="thumbnail base64 string")
      kb_id = CharField(max_length=256, null=False, index=True)
      parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
      # A callable default gives every new instance its own dict; a literal
      # dict would be one object shared (and mutated) across all instances.
      parser_config = JSONField(null=False, default=lambda: {"pages": [[1, 1000000]]})
      source_type = CharField(
          max_length=128,
          null=False,
          default="local",
          help_text="where dose this document come from",
          index=True)
      type = CharField(max_length=32, null=False, help_text="file extension", index=True)
      created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
      name = CharField(max_length=255, null=True, help_text="file name", index=True)
      location = CharField(max_length=255, null=True, help_text="where dose it store", index=True)
      size = IntegerField(default=0, index=True)
      token_num = IntegerField(default=0, index=True)
      chunk_num = IntegerField(default=0, index=True)
      progress = FloatField(default=0, index=True)
      progress_msg = TextField(null=True, help_text="process message", default="")
      process_begin_at = DateTimeField(null=True, index=True)
      process_duation = FloatField(default=0)

      run = CharField(
          max_length=1,
          null=True,
          help_text="start to run processing or cancel.(1: run it; 2: cancel)",
          default="0",
          index=True)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      class Meta:
          db_table = "document"
  class File(DataBaseModel):
      """File or folder row (table ``file``), linked to its parent folder by ``parent_id``."""

      id = CharField(max_length=32, primary_key=True)
      parent_id = CharField(max_length=32, null=False, help_text="parent folder id", index=True)
      tenant_id = CharField(max_length=32, null=False, help_text="tenant id", index=True)
      created_by = CharField(max_length=32, null=False, help_text="who created it", index=True)
      name = CharField(max_length=255, null=False, help_text="file name or folder name", index=True)
      location = CharField(max_length=255, null=True, help_text="where dose it store", index=True)
      size = IntegerField(default=0, index=True)
      type = CharField(
          max_length=32,
          null=False,
          help_text="file extension",
          index=True)
      source_type = CharField(
          max_length=128,
          null=False,
          default="",
          help_text="where dose this document come from",
          index=True)

      class Meta:
          db_table = "file"
  class File2Document(DataBaseModel):
      """Row of table ``file2document`` pairing a file id with a document id."""

      id = CharField(max_length=32, primary_key=True)
      file_id = CharField(max_length=32, null=True, help_text="file id", index=True)
      document_id = CharField(max_length=32, null=True, help_text="document id", index=True)

      class Meta:
          db_table = "file2document"
  class Task(DataBaseModel):
      """Processing task for a page range of one document, with progress and retry count.

      No ``Meta.db_table`` is declared, so the table name is left to peewee's default.
      """

      id = CharField(max_length=32, primary_key=True)
      doc_id = CharField(max_length=32, null=False, index=True)
      from_page = IntegerField(default=0)
      to_page = IntegerField(default=-1)
      begin_at = DateTimeField(null=True, index=True)
      process_duation = FloatField(default=0)
      progress = FloatField(default=0, index=True)
      progress_msg = TextField(null=True, help_text="process message", default="")
      retry_count = IntegerField(default=0)
  class Dialog(DataBaseModel):
      """Dialog application row (table ``dialog``): LLM, prompt and retrieval settings.

      The JSON fields use callable defaults so that each new instance gets its
      own dict/list; a literal default would be one object shared (and
      mutated) across all instances.
      """

      id = CharField(max_length=32, primary_key=True)
      tenant_id = CharField(max_length=32, null=False, index=True)
      name = CharField(max_length=255, null=True, help_text="dialog application name", index=True)
      description = TextField(null=True, help_text="Dialog description")
      icon = TextField(null=True, help_text="icon base64 string")
      # The default language is picked once, at import time, from the LANG environment variable.
      language = CharField(
          max_length=32,
          null=True,
          default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
          help_text="English|Chinese",
          index=True)
      llm_id = CharField(max_length=128, null=False, help_text="default llm ID")

      llm_setting = JSONField(
          null=False,
          default=lambda: {
              "temperature": 0.1,
              "top_p": 0.3,
              "frequency_penalty": 0.7,
              "presence_penalty": 0.4,
              "max_tokens": 512,
          })
      prompt_type = CharField(max_length=16, null=False, default="simple", help_text="simple|advanced", index=True)
      prompt_config = JSONField(
          null=False,
          default=lambda: {
              "system": "",
              "prologue": "Hi! I'm your assistant, what can I do for you?",
              "parameters": [],
              "empty_response": "Sorry! No relevant content was found in the knowledge base!",
          })

      similarity_threshold = FloatField(default=0.2)
      vector_similarity_weight = FloatField(default=0.3)

      top_n = IntegerField(default=6)
      top_k = IntegerField(default=1024)

      do_refer = CharField(
          max_length=1,
          null=False,
          default="1",
          help_text="it needs to insert reference index into answer or not")

      rerank_id = CharField(max_length=128, null=False, help_text="default rerank model ID")

      kb_ids = JSONField(null=False, default=list)
      status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1", index=True)

      class Meta:
          db_table = "dialog"
  class Conversation(DataBaseModel):
      """Conversation row (table ``conversation``) storing messages and references for a dialog."""

      id = CharField(max_length=32, primary_key=True)
      dialog_id = CharField(max_length=32, null=False, index=True)
      name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
      message = JSONField(null=True)
      # ``list`` (a callable) gives each new instance its own list instead of
      # one literal list shared across all instances.
      reference = JSONField(null=True, default=list)

      class Meta:
          db_table = "conversation"
  class APIToken(DataBaseModel):
      """API token row (table ``api_token``), keyed by ``(tenant_id, token)``."""

      tenant_id = CharField(max_length=32, null=False, index=True)
      token = CharField(max_length=255, null=False, index=True)
      dialog_id = CharField(max_length=32, null=False, index=True)
      source = CharField(
          max_length=16,
          null=True,
          help_text="none|agent|dialog",
          index=True)

      class Meta:
          db_table = "api_token"
          primary_key = CompositeKey('tenant_id', 'token')
  class API4Conversation(DataBaseModel):
      """Conversation row (table ``api_4_conversation``) with token, duration, round and thumb-up counters."""

      id = CharField(max_length=32, primary_key=True)
      dialog_id = CharField(max_length=32, null=False, index=True)
      user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
      message = JSONField(null=True)
      # ``list`` (a callable) gives each new instance its own list instead of
      # one literal list shared across all instances.
      reference = JSONField(null=True, default=list)
      tokens = IntegerField(default=0)
      source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)

      duration = FloatField(default=0, index=True)
      round = IntegerField(default=0, index=True)
      thumb_up = IntegerField(default=0, index=True)

      class Meta:
          db_table = "api_4_conversation"
  class UserCanvas(DataBaseModel):
      """User-owned canvas row (table ``user_canvas``); its definition is stored as JSON in ``dsl``."""

      id = CharField(max_length=32, primary_key=True)
      avatar = TextField(null=True, help_text="avatar base64 string")
      user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
      title = CharField(max_length=255, null=True, help_text="Canvas title")
      description = TextField(null=True, help_text="Canvas description")
      canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
      # ``dict`` (a callable) gives each new instance its own dict instead of
      # one literal dict shared across all instances.
      dsl = JSONField(null=True, default=dict)

      class Meta:
          db_table = "user_canvas"
  class CanvasTemplate(DataBaseModel):
      """Canvas template row (table ``canvas_template``); its definition is stored as JSON in ``dsl``."""

      id = CharField(max_length=32, primary_key=True)
      avatar = TextField(null=True, help_text="avatar base64 string")
      title = CharField(max_length=255, null=True, help_text="Canvas title")
      description = TextField(null=True, help_text="Canvas description")
      canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
      # ``dict`` (a callable) gives each new instance its own dict instead of
      # one literal dict shared across all instances.
      dsl = JSONField(null=True, default=dict)

      class Meta:
          db_table = "canvas_template"
  def _run_migration_step(description, step):
      """Run one best-effort schema change, logging a failure instead of hiding it."""
      try:
          step()
      except Exception as exc:
          # Steps stay best-effort, as before: a failing step never aborts the
          # remaining ones. Presumably most failures just mean the change is
          # already in place - the log line makes the other cases visible.
          LOGGER.info(f"migration step '{description}' skipped: {exc}")


  def migrate_db():
      """Apply the incremental schema changes to an existing database.

      All steps run inside one ``DB.transaction()`` block. Each step is
      attempted independently; a failure is logged and the next step still runs.
      """
      with DB.transaction():
          migrator = DatabaseMigrator[DATABASE_TYPE.upper()].value(DB)

          def _rebuild_llm_primary_key():
              DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
              DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')

          # Lambdas defer building each operation so that any error raised
          # while constructing it is handled like an execution error.
          steps = [
              ("add file.source_type", lambda: migrate(
                  migrator.add_column('file', 'source_type',
                                      CharField(max_length=128, null=False, default="",
                                                help_text="where dose this document come from",
                                                index=True)))),
              ("add tenant.rerank_id", lambda: migrate(
                  migrator.add_column('tenant', 'rerank_id',
                                      CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3",
                                                help_text="default rerank model ID")))),
              ("add dialog.rerank_id", lambda: migrate(
                  migrator.add_column('dialog', 'rerank_id',
                                      CharField(max_length=128, null=False, default="",
                                                help_text="default rerank model ID")))),
              ("add dialog.top_k", lambda: migrate(
                  migrator.add_column('dialog', 'top_k', IntegerField(default=1024)))),
              ("alter tenant_llm.api_key", lambda: migrate(
                  migrator.alter_column_type('tenant_llm', 'api_key',
                                             CharField(max_length=1024, null=True,
                                                       help_text="API KEY", index=True)))),
              ("add api_token.source", lambda: migrate(
                  migrator.add_column('api_token', 'source',
                                      CharField(max_length=16, null=True,
                                                help_text="none|agent|dialog", index=True)))),
              ("add tenant.tts_id", lambda: migrate(
                  migrator.add_column("tenant", "tts_id",
                                      CharField(max_length=256, null=True,
                                                help_text="default tts model ID", index=True)))),
              ("add api_4_conversation.source", lambda: migrate(
                  migrator.add_column('api_4_conversation', 'source',
                                      CharField(max_length=16, null=True,
                                                help_text="none|agent|dialog", index=True)))),
              ("rebuild llm primary key", _rebuild_llm_primary_key),
              ("add task.retry_count", lambda: migrate(
                  migrator.add_column('task', 'retry_count', IntegerField(default=0)))),
              ("alter api_token.dialog_id", lambda: migrate(
                  migrator.alter_column_type('api_token', 'dialog_id',
                                             CharField(max_length=32, null=True, index=True)))),
          ]
          for description, step in steps:
              _run_migration_step(description, step)