選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

db_models.py 36KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
import copy
import hashlib
import inspect
import logging
import operator
import os
import sys
import time
import typing
from enum import Enum
from functools import wraps

from flask_login import UserMixin
from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from peewee import (
    BigIntegerField, BooleanField, CharField,
    CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
    Field, Model, Metadata
)
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase

from api.db import SerializedType, ParserType
from api import settings
from api import utils
  36. def singleton(cls, *args, **kw):
  37. instances = {}
  38. def _singleton():
  39. key = str(cls) + str(os.getpid())
  40. if key not in instances:
  41. instances[key] = cls(*args, **kw)
  42. return instances[key]
  43. return _singleton
  44. CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
  45. AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
  46. "create",
  47. "start",
  48. "end",
  49. "update",
  50. "read_access",
  51. "write_access"}
  52. class TextFieldType(Enum):
  53. MYSQL = 'LONGTEXT'
  54. POSTGRES = 'TEXT'
  55. class LongTextField(TextField):
  56. field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
  57. class JSONField(LongTextField):
  58. default_value = {}
  59. def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
  60. self._object_hook = object_hook
  61. self._object_pairs_hook = object_pairs_hook
  62. super().__init__(**kwargs)
  63. def db_value(self, value):
  64. if value is None:
  65. value = self.default_value
  66. return utils.json_dumps(value)
  67. def python_value(self, value):
  68. if not value:
  69. return self.default_value
  70. return utils.json_loads(
  71. value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  72. class ListField(JSONField):
  73. default_value = []
  74. class SerializedField(LongTextField):
  75. def __init__(self, serialized_type=SerializedType.PICKLE,
  76. object_hook=None, object_pairs_hook=None, **kwargs):
  77. self._serialized_type = serialized_type
  78. self._object_hook = object_hook
  79. self._object_pairs_hook = object_pairs_hook
  80. super().__init__(**kwargs)
  81. def db_value(self, value):
  82. if self._serialized_type == SerializedType.PICKLE:
  83. return utils.serialize_b64(value, to_str=True)
  84. elif self._serialized_type == SerializedType.JSON:
  85. if value is None:
  86. return None
  87. return utils.json_dumps(value, with_type=True)
  88. else:
  89. raise ValueError(
  90. f"the serialized type {self._serialized_type} is not supported")
  91. def python_value(self, value):
  92. if self._serialized_type == SerializedType.PICKLE:
  93. return utils.deserialize_b64(value)
  94. elif self._serialized_type == SerializedType.JSON:
  95. if value is None:
  96. return {}
  97. return utils.json_loads(
  98. value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  99. else:
  100. raise ValueError(
  101. f"the serialized type {self._serialized_type} is not supported")
  102. def is_continuous_field(cls: typing.Type) -> bool:
  103. if cls in CONTINUOUS_FIELD_TYPE:
  104. return True
  105. for p in cls.__bases__:
  106. if p in CONTINUOUS_FIELD_TYPE:
  107. return True
  108. elif p is not Field and p is not object:
  109. if is_continuous_field(p):
  110. return True
  111. else:
  112. return False
  113. def auto_date_timestamp_field():
  114. return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  115. def auto_date_timestamp_db_field():
  116. return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  117. def remove_field_name_prefix(field_name):
  118. return field_name[2:] if field_name.startswith('f_') else field_name
  119. class BaseModel(Model):
  120. create_time = BigIntegerField(null=True, index=True)
  121. create_date = DateTimeField(null=True, index=True)
  122. update_time = BigIntegerField(null=True, index=True)
  123. update_date = DateTimeField(null=True, index=True)
  124. def to_json(self):
  125. # This function is obsolete
  126. return self.to_dict()
  127. def to_dict(self):
  128. return self.__dict__['__data__']
  129. def to_human_model_dict(self, only_primary_with: list = None):
  130. model_dict = self.__dict__['__data__']
  131. if not only_primary_with:
  132. return {remove_field_name_prefix(
  133. k): v for k, v in model_dict.items()}
  134. human_model_dict = {}
  135. for k in self._meta.primary_key.field_names:
  136. human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
  137. for k in only_primary_with:
  138. human_model_dict[k] = model_dict[f'f_{k}']
  139. return human_model_dict
  140. @property
  141. def meta(self) -> Metadata:
  142. return self._meta
  143. @classmethod
  144. def get_primary_keys_name(cls):
  145. return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
  146. cls._meta.primary_key.name]
  147. @classmethod
  148. def getter_by(cls, attr):
  149. return operator.attrgetter(attr)(cls)
  150. @classmethod
  151. def query(cls, reverse=None, order_by=None, **kwargs):
  152. filters = []
  153. for f_n, f_v in kwargs.items():
  154. attr_name = '%s' % f_n
  155. if not hasattr(cls, attr_name) or f_v is None:
  156. continue
  157. if type(f_v) in {list, set}:
  158. f_v = list(f_v)
  159. if is_continuous_field(type(getattr(cls, attr_name))):
  160. if len(f_v) == 2:
  161. for i, v in enumerate(f_v):
  162. if isinstance(
  163. v, str) and f_n in auto_date_timestamp_field():
  164. # time type: %Y-%m-%d %H:%M:%S
  165. f_v[i] = utils.date_string_to_timestamp(v)
  166. lt_value = f_v[0]
  167. gt_value = f_v[1]
  168. if lt_value is not None and gt_value is not None:
  169. filters.append(
  170. cls.getter_by(attr_name).between(
  171. lt_value, gt_value))
  172. elif lt_value is not None:
  173. filters.append(
  174. operator.attrgetter(attr_name)(cls) >= lt_value)
  175. elif gt_value is not None:
  176. filters.append(
  177. operator.attrgetter(attr_name)(cls) <= gt_value)
  178. else:
  179. filters.append(operator.attrgetter(attr_name)(cls) << f_v)
  180. else:
  181. filters.append(operator.attrgetter(attr_name)(cls) == f_v)
  182. if filters:
  183. query_records = cls.select().where(*filters)
  184. if reverse is not None:
  185. if not order_by or not hasattr(cls, f"{order_by}"):
  186. order_by = "create_time"
  187. if reverse is True:
  188. query_records = query_records.order_by(
  189. cls.getter_by(f"{order_by}").desc())
  190. elif reverse is False:
  191. query_records = query_records.order_by(
  192. cls.getter_by(f"{order_by}").asc())
  193. return [query_record for query_record in query_records]
  194. else:
  195. return []
  196. @classmethod
  197. def insert(cls, __data=None, **insert):
  198. if isinstance(__data, dict) and __data:
  199. __data[cls._meta.combined["create_time"]
  200. ] = utils.current_timestamp()
  201. if insert:
  202. insert["create_time"] = utils.current_timestamp()
  203. return super().insert(__data, **insert)
  204. # update and insert will call this method
  205. @classmethod
  206. def _normalize_data(cls, data, kwargs):
  207. normalized = super()._normalize_data(data, kwargs)
  208. if not normalized:
  209. return {}
  210. normalized[cls._meta.combined["update_time"]
  211. ] = utils.current_timestamp()
  212. for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
  213. if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
  214. cls._meta.combined[f"{f_n}_time"] in normalized and \
  215. normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
  216. normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
  217. normalized[cls._meta.combined[f"{f_n}_time"]])
  218. return normalized
  219. class JsonSerializedField(SerializedField):
  220. def __init__(self, object_hook=utils.from_dict_hook,
  221. object_pairs_hook=None, **kwargs):
  222. super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
  223. object_pairs_hook=object_pairs_hook, **kwargs)
  224. class PooledDatabase(Enum):
  225. MYSQL = PooledMySQLDatabase
  226. POSTGRES = PooledPostgresqlDatabase
  227. class DatabaseMigrator(Enum):
  228. MYSQL = MySQLMigrator
  229. POSTGRES = PostgresqlMigrator
  230. @singleton
  231. class BaseDataBase:
  232. def __init__(self):
  233. database_config = settings.DATABASE.copy()
  234. db_name = database_config.pop("name")
  235. self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
  236. logging.info('init database on cluster mode successfully')
  237. class PostgresDatabaseLock:
  238. def __init__(self, lock_name, timeout=10, db=None):
  239. self.lock_name = lock_name
  240. self.timeout = int(timeout)
  241. self.db = db if db else DB
  242. def lock(self):
  243. cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", self.timeout)
  244. ret = cursor.fetchone()
  245. if ret[0] == 0:
  246. raise Exception(f'acquire postgres lock {self.lock_name} timeout')
  247. elif ret[0] == 1:
  248. return True
  249. else:
  250. raise Exception(f'failed to acquire lock {self.lock_name}')
  251. def unlock(self):
  252. cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", self.timeout)
  253. ret = cursor.fetchone()
  254. if ret[0] == 0:
  255. raise Exception(
  256. f'postgres lock {self.lock_name} was not established by this thread')
  257. elif ret[0] == 1:
  258. return True
  259. else:
  260. raise Exception(f'postgres lock {self.lock_name} does not exist')
  261. def __enter__(self):
  262. if isinstance(self.db, PostgresDatabaseLock):
  263. self.lock()
  264. return self
  265. def __exit__(self, exc_type, exc_val, exc_tb):
  266. if isinstance(self.db, PostgresDatabaseLock):
  267. self.unlock()
  268. def __call__(self, func):
  269. @wraps(func)
  270. def magic(*args, **kwargs):
  271. with self:
  272. return func(*args, **kwargs)
  273. return magic
  274. class MysqlDatabaseLock:
  275. def __init__(self, lock_name, timeout=10, db=None):
  276. self.lock_name = lock_name
  277. self.timeout = int(timeout)
  278. self.db = db if db else DB
  279. def lock(self):
  280. # SQL parameters only support %s format placeholders
  281. cursor = self.db.execute_sql(
  282. "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
  283. ret = cursor.fetchone()
  284. if ret[0] == 0:
  285. raise Exception(f'acquire mysql lock {self.lock_name} timeout')
  286. elif ret[0] == 1:
  287. return True
  288. else:
  289. raise Exception(f'failed to acquire lock {self.lock_name}')
  290. def unlock(self):
  291. cursor = self.db.execute_sql(
  292. "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
  293. ret = cursor.fetchone()
  294. if ret[0] == 0:
  295. raise Exception(
  296. f'mysql lock {self.lock_name} was not established by this thread')
  297. elif ret[0] == 1:
  298. return True
  299. else:
  300. raise Exception(f'mysql lock {self.lock_name} does not exist')
  301. def __enter__(self):
  302. if isinstance(self.db, PooledMySQLDatabase):
  303. self.lock()
  304. return self
  305. def __exit__(self, exc_type, exc_val, exc_tb):
  306. if isinstance(self.db, PooledMySQLDatabase):
  307. self.unlock()
  308. def __call__(self, func):
  309. @wraps(func)
  310. def magic(*args, **kwargs):
  311. with self:
  312. return func(*args, **kwargs)
  313. return magic
  314. class DatabaseLock(Enum):
  315. MYSQL = MysqlDatabaseLock
  316. POSTGRES = PostgresDatabaseLock
# Per-process pooled connection shared by every model (DataBaseModel.Meta.database).
DB = BaseDataBase().database_connection
# Attach the backend-specific lock class so callers can use
# ``with DB.lock(name, timeout): ...`` or ``@DB.lock(name)``.
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
  319. def close_connection():
  320. try:
  321. if DB:
  322. DB.close_stale(age=30)
  323. except Exception as e:
  324. logging.exception(e)
  325. class DataBaseModel(BaseModel):
  326. class Meta:
  327. database = DB
  328. @DB.connection_context()
  329. def init_database_tables(alter_fields=[]):
  330. members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
  331. table_objs = []
  332. create_failed_list = []
  333. for name, obj in members:
  334. if obj != DataBaseModel and issubclass(obj, DataBaseModel):
  335. table_objs.append(obj)
  336. logging.debug(f"start create table {obj.__name__}")
  337. try:
  338. obj.create_table()
  339. logging.debug(f"create table success: {obj.__name__}")
  340. except Exception as e:
  341. logging.exception(e)
  342. create_failed_list.append(obj.__name__)
  343. if create_failed_list:
  344. logging.error(f"create tables failed: {create_failed_list}")
  345. raise Exception(f"create tables failed: {create_failed_list}")
  346. migrate_db()
  347. def fill_db_model_object(model_object, human_model_dict):
  348. for k, v in human_model_dict.items():
  349. attr_name = '%s' % k
  350. if hasattr(model_object.__class__, attr_name):
  351. setattr(model_object, attr_name, v)
  352. return model_object
class User(DataBaseModel, UserMixin):
    """Account record (table ``user``); also the flask-login user via ``UserMixin``."""
    id = CharField(max_length=32, primary_key=True)
    access_token = CharField(max_length=255, null=True, index=True)
    nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
    # NOTE(review): indexing a password column looks unnecessary; confirm intent.
    password = CharField(max_length=255, null=True, help_text="password", index=True)
    email = CharField(
        max_length=255,
        null=False,
        help_text="email",
        index=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    # Default is evaluated once at import time from the server's $LANG.
    language = CharField(
        max_length=32,
        null=True,
        help_text="English|Chinese",
        default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
        index=True)
    color_schema = CharField(
        max_length=32,
        null=True,
        help_text="Bright|Dark",
        default="Bright",
        index=True)
    timezone = CharField(
        max_length=64,
        null=True,
        help_text="Timezone",
        default="UTC+8\tAsia/Shanghai",
        index=True)
    last_login_time = DateTimeField(null=True, index=True)
    # NOTE(review): these CharFields shadow UserMixin's boolean attributes and
    # hold "0"/"1" strings. "0" is truthy in Python, so a plain truth test on
    # e.g. user.is_active treats "0" as active. Confirm how callers check them.
    is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
    is_active = CharField(max_length=1, null=False, default="1", index=True)
    is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
    login_channel = CharField(null=True, help_text="from which user login", index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)
    is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)

    def __str__(self):
        return self.email

    def get_id(self):
        # Overrides UserMixin.get_id: instead of the primary key, return the
        # access_token signed with SECRET_KEY by a URL-safe timed serializer.
        jwt = Serializer(secret_key=settings.SECRET_KEY)
        return jwt.dumps(str(self.access_token))

    class Meta:
        db_table = "user"
class Tenant(DataBaseModel):
    """Tenant record (table ``tenant``) with default model IDs, parser list, credit and status."""
    id = CharField(max_length=32, primary_key=True)
    name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
    public_key = CharField(max_length=255, null=True, index=True)
    # Default model IDs, one per model kind.
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
    embd_id = CharField(
        max_length=128,
        null=False,
        help_text="default embedding model ID",
        index=True)
    asr_id = CharField(
        max_length=128,
        null=False,
        help_text="default ASR model ID",
        index=True)
    img2txt_id = CharField(
        max_length=128,
        null=False,
        help_text="default image to text model ID",
        index=True)
    rerank_id = CharField(
        max_length=128,
        null=False,
        help_text="default rerank model ID",
        index=True)
    tts_id = CharField(
        max_length=256,
        null=True,
        help_text="default tts model ID",
        index=True)
    # Encoded list of available document processors (format not shown here).
    parser_ids = CharField(
        max_length=256,
        null=False,
        help_text="document processors",
        index=True)
    credit = IntegerField(default=512, index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    class Meta:
        db_table = "tenant"
class UserTenant(DataBaseModel):
    """Membership link between a user and a tenant, with role and inviter (table ``user_tenant``)."""
    id = CharField(max_length=32, primary_key=True)
    user_id = CharField(max_length=32, null=False, index=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    # Value of a UserTenantRole (defined outside this view).
    role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
    invited_by = CharField(max_length=32, null=False, index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    class Meta:
        db_table = "user_tenant"
class InvitationCode(DataBaseModel):
    """Invitation code record (table ``invitation_code``)."""
    id = CharField(max_length=32, primary_key=True)
    code = CharField(max_length=32, null=False, index=True)
    visit_time = DateTimeField(null=True, index=True)
    # Presumably the user/tenant that redeemed the code; verify against callers.
    user_id = CharField(max_length=32, null=True, index=True)
    tenant_id = CharField(max_length=32, null=True, index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    class Meta:
        db_table = "invitation_code"
class LLMFactories(DataBaseModel):
    """LLM provider ("factory") catalogue, keyed by name (table ``llm_factories``)."""
    name = CharField(
        max_length=128,
        null=False,
        help_text="LLM factory name",
        primary_key=True)
    logo = TextField(null=True, help_text="llm logo base64")
    # Comma-separated capability tags, e.g. "LLM, Text Embedding, Image2Text, ASR".
    tags = CharField(
        max_length=255,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, ASR",
        index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "llm_factories"
class LLM(DataBaseModel):
    """Catalogue of models offered by each factory; primary key is (fid, llm_name) (table ``llm``)."""
    # LLMs dictionary
    llm_name = CharField(
        max_length=128,
        null=False,
        help_text="LLM name",
        index=True)
    model_type = CharField(
        max_length=128,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, ASR",
        index=True)
    # Factory name (presumably LLMFactories.name; no FK is declared).
    fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
    max_tokens = IntegerField(default=0)
    tags = CharField(
        max_length=255,
        null=False,
        help_text="LLM, Text Embedding, Image2Text, Chat, 32k...",
        index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    def __str__(self):
        return self.llm_name

    class Meta:
        primary_key = CompositeKey('fid', 'llm_name')
        db_table = "llm"
class TenantLLM(DataBaseModel):
    """Per-tenant model credentials and token accounting (table ``tenant_llm``).

    Primary key is (tenant_id, llm_factory, llm_name).
    """
    tenant_id = CharField(max_length=32, null=False, index=True)
    llm_factory = CharField(
        max_length=128,
        null=False,
        help_text="LLM factory name",
        index=True)
    model_type = CharField(
        max_length=128,
        null=True,
        help_text="LLM, Text Embedding, Image2Text, ASR",
        index=True)
    # NOTE(review): part of the composite primary key yet declared null=True;
    # the default "" avoids NULLs on insert. Confirm intended.
    llm_name = CharField(
        max_length=128,
        null=True,
        help_text="LLM name",
        default="",
        index=True)
    # NOTE(review): an index on a 1024-char column may exceed MySQL's index
    # key length limit under utf8mb4; verify against the deployed charset.
    api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
    api_base = CharField(max_length=255, null=True, help_text="API Base")
    max_tokens = IntegerField(default=8192, index=True)
    used_tokens = IntegerField(default=0, index=True)

    def __str__(self):
        return self.llm_name

    class Meta:
        db_table = "tenant_llm"
        primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
class Knowledgebase(DataBaseModel):
    """Knowledge base owned by a tenant, with retrieval and parsing settings (table ``knowledgebase``)."""
    id = CharField(max_length=32, primary_key=True)
    avatar = TextField(null=True, help_text="avatar base64 string")
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(
        max_length=128,
        null=False,
        help_text="KB name",
        index=True)
    # Default is evaluated once at import time from the server's $LANG.
    language = CharField(
        max_length=32,
        null=True,
        default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
        help_text="English|Chinese",
        index=True)
    description = TextField(null=True, help_text="KB description")
    embd_id = CharField(
        max_length=128,
        null=False,
        help_text="default embedding model ID",
        index=True)
    # "me" or "team" (per help_text).
    permission = CharField(
        max_length=16,
        null=False,
        help_text="me|team",
        default="me",
        index=True)
    created_by = CharField(max_length=32, null=False, index=True)
    # Aggregate counters for the documents in this knowledge base.
    doc_num = IntegerField(default=0, index=True)
    token_num = IntegerField(default=0, index=True)
    chunk_num = IntegerField(default=0, index=True)
    # Retrieval tuning knobs; their exact semantics live in the retrieval code.
    similarity_threshold = FloatField(default=0.2, index=True)
    vector_similarity_weight = FloatField(default=0.3, index=True)
    parser_id = CharField(
        max_length=32,
        null=False,
        help_text="default parser ID",
        default=ParserType.NAIVE.value,
        index=True)
    # "pages" presumably holds [from, to] page ranges; verify against the parser.
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    pagerank = IntegerField(default=0, index=False)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    def __str__(self):
        return self.name

    class Meta:
        db_table = "knowledgebase"
class Document(DataBaseModel):
    """A document in a knowledge base, with parse settings and processing progress (table ``document``)."""
    id = CharField(max_length=32, primary_key=True)
    thumbnail = TextField(null=True, help_text="thumbnail base64 string")
    kb_id = CharField(max_length=256, null=False, index=True)
    parser_id = CharField(
        max_length=32,
        null=False,
        help_text="default parser ID",
        index=True)
    # "pages" presumably holds [from, to] page ranges; verify against the parser.
    parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
    source_type = CharField(
        max_length=128,
        null=False,
        default="local",
        help_text="where dose this document come from",
        index=True)
    type = CharField(max_length=32, null=False, help_text="file extension",
                     index=True)
    created_by = CharField(
        max_length=32,
        null=False,
        help_text="who created it",
        index=True)
    name = CharField(
        max_length=255,
        null=True,
        help_text="file name",
        index=True)
    location = CharField(
        max_length=255,
        null=True,
        help_text="where dose it store",
        index=True)
    size = IntegerField(default=0, index=True)
    token_num = IntegerField(default=0, index=True)
    chunk_num = IntegerField(default=0, index=True)
    # Processing progress; the value range is set by the task runner (not shown).
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(
        null=True,
        help_text="process message",
        default="")
    process_begin_at = DateTimeField(null=True, index=True)
    # Column name keeps the historical "duation" spelling; renaming it would need a migration.
    process_duation = FloatField(default=0)
    meta_fields = JSONField(null=True, default={})
    # Per help_text: "1" = run processing, "2" = cancel; the default is "0".
    run = CharField(
        max_length=1,
        null=True,
        help_text="start to run processing or cancel.(1: run it; 2: cancel)",
        default="0",
        index=True)
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    class Meta:
        db_table = "document"
class File(DataBaseModel):
    """File or folder entry in a tenant's file tree, linked to its parent by ``parent_id`` (table ``file``)."""
    id = CharField(
        max_length=32,
        primary_key=True)
    parent_id = CharField(
        max_length=32,
        null=False,
        help_text="parent folder id",
        index=True)
    tenant_id = CharField(
        max_length=32,
        null=False,
        help_text="tenant id",
        index=True)
    created_by = CharField(
        max_length=32,
        null=False,
        help_text="who created it",
        index=True)
    name = CharField(
        max_length=255,
        null=False,
        help_text="file name or folder name",
        index=True)
    location = CharField(
        max_length=255,
        null=True,
        help_text="where dose it store",
        index=True)
    size = IntegerField(default=0, index=True)
    type = CharField(max_length=32, null=False, help_text="file extension", index=True)
    source_type = CharField(
        max_length=128,
        null=False,
        default="",
        help_text="where dose this document come from", index=True)

    class Meta:
        db_table = "file"
class File2Document(DataBaseModel):
    """Link between a ``file`` row and a ``document`` row (table ``file2document``)."""
    id = CharField(
        max_length=32,
        primary_key=True)
    file_id = CharField(
        max_length=32,
        null=True,
        help_text="file id",
        index=True)
    document_id = CharField(
        max_length=32,
        null=True,
        help_text="document id",
        index=True)

    class Meta:
        db_table = "file2document"
class Task(DataBaseModel):
    """Processing task for a page range of a document.

    No ``Meta.db_table`` is declared, so peewee derives the table name from
    the class name.
    """
    id = CharField(max_length=32, primary_key=True)
    doc_id = CharField(max_length=32, null=False, index=True)
    # Page range handled by this task; the default upper bound is effectively "all pages".
    from_page = IntegerField(default=0)
    to_page = IntegerField(default=100000000)
    begin_at = DateTimeField(null=True, index=True)
    # Column name keeps the historical "duation" spelling; renaming it would need a migration.
    process_duation = FloatField(default=0)
    progress = FloatField(default=0, index=True)
    progress_msg = TextField(
        null=True,
        help_text="process message",
        default="")
    retry_count = IntegerField(default=0)
    digest = TextField(null=True, help_text="task digest", default="")
    # Presumably a serialized list of chunk IDs produced by the task; verify the format.
    chunk_ids = LongTextField(null=True, help_text="chunk ids", default="")
class Dialog(DataBaseModel):
    """Chat assistant configuration: model, prompts, retrieval knobs and linked KBs (table ``dialog``)."""
    id = CharField(max_length=32, primary_key=True)
    tenant_id = CharField(max_length=32, null=False, index=True)
    name = CharField(
        max_length=255,
        null=True,
        help_text="dialog application name",
        index=True)
    description = TextField(null=True, help_text="Dialog description")
    icon = TextField(null=True, help_text="icon base64 string")
    # Default is evaluated once at import time from the server's $LANG.
    language = CharField(
        max_length=32,
        null=True,
        default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
        help_text="English|Chinese",
        index=True)
    llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
    # Default sampling parameters for the chat model.
    llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
                                                 "presence_penalty": 0.4, "max_tokens": 512})
    prompt_type = CharField(
        max_length=16,
        null=False,
        default="simple",
        help_text="simple|advanced",
        index=True)
    # System prompt, greeting, prompt parameters and the fallback reply used
    # when nothing relevant is retrieved.
    prompt_config = JSONField(null=False,
                              default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
                                       "parameters": [],
                                       "empty_response": "Sorry! No relevant content was found in the knowledge base!"})
    similarity_threshold = FloatField(default=0.2)
    vector_similarity_weight = FloatField(default=0.3)
    top_n = IntegerField(default=6)
    top_k = IntegerField(default=1024)
    # "1"/"0" flag (per help_text): insert reference indices into answers or not.
    do_refer = CharField(
        max_length=1,
        null=False,
        default="1",
        help_text="it needs to insert reference index into answer or not")
    rerank_id = CharField(
        max_length=128,
        null=False,
        help_text="default rerank model ID")
    # IDs of the knowledge bases this dialog retrieves from.
    kb_ids = JSONField(null=False, default=[])
    # "1" = valid, "0" = discarded (per help_text).
    status = CharField(
        max_length=1,
        null=True,
        help_text="is it validate(0: wasted, 1: validate)",
        default="1",
        index=True)

    class Meta:
        db_table = "dialog"
  781. class Conversation(DataBaseModel):
  782. id = CharField(max_length=32, primary_key=True)
  783. dialog_id = CharField(max_length=32, null=False, index=True)
  784. name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
  785. message = JSONField(null=True)
  786. reference = JSONField(null=True, default=[])
  787. user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)
  788. class Meta:
  789. db_table = "conversation"
  790. class APIToken(DataBaseModel):
  791. tenant_id = CharField(max_length=32, null=False, index=True)
  792. token = CharField(max_length=255, null=False, index=True)
  793. dialog_id = CharField(max_length=32, null=False, index=True)
  794. source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
  795. beta = CharField(max_length=255, null=True, index=True)
  796. class Meta:
  797. db_table = "api_token"
  798. primary_key = CompositeKey('tenant_id', 'token')
  799. class API4Conversation(DataBaseModel):
  800. id = CharField(max_length=32, primary_key=True)
  801. dialog_id = CharField(max_length=32, null=False, index=True)
  802. user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
  803. message = JSONField(null=True)
  804. reference = JSONField(null=True, default=[])
  805. tokens = IntegerField(default=0)
  806. source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
  807. dsl = JSONField(null=True, default={})
  808. duration = FloatField(default=0, index=True)
  809. round = IntegerField(default=0, index=True)
  810. thumb_up = IntegerField(default=0, index=True)
  811. class Meta:
  812. db_table = "api_4_conversation"
  813. class UserCanvas(DataBaseModel):
  814. id = CharField(max_length=32, primary_key=True)
  815. avatar = TextField(null=True, help_text="avatar base64 string")
  816. user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
  817. title = CharField(max_length=255, null=True, help_text="Canvas title")
  818. description = TextField(null=True, help_text="Canvas description")
  819. canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
  820. dsl = JSONField(null=True, default={})
  821. class Meta:
  822. db_table = "user_canvas"
  823. class CanvasTemplate(DataBaseModel):
  824. id = CharField(max_length=32, primary_key=True)
  825. avatar = TextField(null=True, help_text="avatar base64 string")
  826. title = CharField(max_length=255, null=True, help_text="Canvas title")
  827. description = TextField(null=True, help_text="Canvas description")
  828. canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
  829. dsl = JSONField(null=True, default={})
  830. class Meta:
  831. db_table = "canvas_template"
  832. def migrate_db():
  833. with DB.transaction():
  834. migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
  835. try:
  836. migrate(
  837. migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
  838. help_text="where dose this document come from",
  839. index=True))
  840. )
  841. except Exception:
  842. pass
  843. try:
  844. migrate(
  845. migrator.add_column('tenant', 'rerank_id',
  846. CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3",
  847. help_text="default rerank model ID"))
  848. )
  849. except Exception:
  850. pass
  851. try:
  852. migrate(
  853. migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="",
  854. help_text="default rerank model ID"))
  855. )
  856. except Exception:
  857. pass
  858. try:
  859. migrate(
  860. migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
  861. )
  862. except Exception:
  863. pass
  864. try:
  865. migrate(
  866. migrator.alter_column_type('tenant_llm', 'api_key',
  867. CharField(max_length=1024, null=True, help_text="API KEY", index=True))
  868. )
  869. except Exception:
  870. pass
  871. try:
  872. migrate(
  873. migrator.add_column('api_token', 'source',
  874. CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
  875. )
  876. except Exception:
  877. pass
  878. try:
  879. migrate(
  880. migrator.add_column("tenant", "tts_id",
  881. CharField(max_length=256, null=True, help_text="default tts model ID", index=True))
  882. )
  883. except Exception:
  884. pass
  885. try:
  886. migrate(
  887. migrator.add_column('api_4_conversation', 'source',
  888. CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
  889. )
  890. except Exception:
  891. pass
  892. try:
  893. DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
  894. DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
  895. except Exception:
  896. pass
  897. try:
  898. migrate(
  899. migrator.add_column('task', 'retry_count', IntegerField(default=0))
  900. )
  901. except Exception:
  902. pass
  903. try:
  904. migrate(
  905. migrator.alter_column_type('api_token', 'dialog_id',
  906. CharField(max_length=32, null=True, index=True))
  907. )
  908. except Exception:
  909. pass
  910. try:
  911. migrate(
  912. migrator.add_column("tenant_llm", "max_tokens", IntegerField(default=8192, index=True))
  913. )
  914. except Exception:
  915. pass
  916. try:
  917. migrate(
  918. migrator.add_column("api_4_conversation", "dsl", JSONField(null=True, default={}))
  919. )
  920. except Exception:
  921. pass
  922. try:
  923. migrate(
  924. migrator.add_column("knowledgebase", "pagerank", IntegerField(default=0, index=False))
  925. )
  926. except Exception:
  927. pass
  928. try:
  929. migrate(
  930. migrator.add_column("api_token", "beta", CharField(max_length=255, null=True, index=True))
  931. )
  932. except Exception:
  933. pass
  934. try:
  935. migrate(
  936. migrator.add_column("task", "digest", TextField(null=True, help_text="task digest", default=""))
  937. )
  938. except Exception:
  939. pass
  940. try:
  941. migrate(
  942. migrator.add_column("task", "chunk_ids", LongTextField(null=True, help_text="chunk ids", default=""))
  943. )
  944. except Exception:
  945. pass
  946. try:
  947. migrate(
  948. migrator.add_column("conversation", "user_id",
  949. CharField(max_length=255, null=True, help_text="user_id", index=True))
  950. )
  951. except Exception:
  952. pass
  953. try:
  954. migrate(
  955. migrator.add_column("document", "meta_fields",
  956. JSONField(null=True, default={}))
  957. )
  958. except Exception:
  959. pass