Вы не можете выбрать более 25 тем. Темы должны начинаться с буквы или цифры, могут содержать дефисы (-) и должны содержать не более 35 символов.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. #
  2. # Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import logging
  17. import inspect
  18. import os
  19. import sys
  20. import operator
  21. from enum import Enum
  22. from functools import wraps
  23. from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
  24. from flask_login import UserMixin
  25. from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
  26. from peewee import (
  27. BigIntegerField, BooleanField, CharField,
  28. CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
  29. Field, Model, Metadata
  30. )
  31. from playhouse.pool import PooledMySQLDatabase, PooledPostgresqlDatabase
  32. from api.db import SerializedType, ParserType
  33. from api import settings
  34. from api import utils
  35. def singleton(cls, *args, **kw):
  36. instances = {}
  37. def _singleton():
  38. key = str(cls) + str(os.getpid())
  39. if key not in instances:
  40. instances[key] = cls(*args, **kw)
  41. return instances[key]
  42. return _singleton
# Field classes that is_continuous_field() recognises; BaseModel.query treats a
# two-element list passed for such a field as a [lower, upper] range.
CONTINUOUS_FIELD_TYPE = {IntegerField, FloatField, DateTimeField}
# Name prefixes of paired "<prefix>_time" / "<prefix>_date" columns.
# BaseModel._normalize_data fills the "_date" column from the "_time" value
# whenever a model declares both.
AUTO_DATE_TIMESTAMP_FIELD_PREFIX = {
    "create",
    "start",
    "end",
    "update",
    "read_access",
    "write_access"}
  51. class TextFieldType(Enum):
  52. MYSQL = 'LONGTEXT'
  53. POSTGRES = 'TEXT'
class LongTextField(TextField):
    """Text field whose column type is looked up in TextFieldType."""

    # Resolved once, when the class body runs, from settings.DATABASE_TYPE;
    # an unknown backend name raises KeyError at import time.
    field_type = TextFieldType[settings.DATABASE_TYPE.upper()].value
  56. class JSONField(LongTextField):
  57. default_value = {}
  58. def __init__(self, object_hook=None, object_pairs_hook=None, **kwargs):
  59. self._object_hook = object_hook
  60. self._object_pairs_hook = object_pairs_hook
  61. super().__init__(**kwargs)
  62. def db_value(self, value):
  63. if value is None:
  64. value = self.default_value
  65. return utils.json_dumps(value)
  66. def python_value(self, value):
  67. if not value:
  68. return self.default_value
  69. return utils.json_loads(
  70. value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
class ListField(JSONField):
    """JSONField variant whose fallback value is an empty list rather than an empty dict."""

    default_value = []
  73. class SerializedField(LongTextField):
  74. def __init__(self, serialized_type=SerializedType.PICKLE,
  75. object_hook=None, object_pairs_hook=None, **kwargs):
  76. self._serialized_type = serialized_type
  77. self._object_hook = object_hook
  78. self._object_pairs_hook = object_pairs_hook
  79. super().__init__(**kwargs)
  80. def db_value(self, value):
  81. if self._serialized_type == SerializedType.PICKLE:
  82. return utils.serialize_b64(value, to_str=True)
  83. elif self._serialized_type == SerializedType.JSON:
  84. if value is None:
  85. return None
  86. return utils.json_dumps(value, with_type=True)
  87. else:
  88. raise ValueError(
  89. f"the serialized type {self._serialized_type} is not supported")
  90. def python_value(self, value):
  91. if self._serialized_type == SerializedType.PICKLE:
  92. return utils.deserialize_b64(value)
  93. elif self._serialized_type == SerializedType.JSON:
  94. if value is None:
  95. return {}
  96. return utils.json_loads(
  97. value, object_hook=self._object_hook, object_pairs_hook=self._object_pairs_hook)
  98. else:
  99. raise ValueError(
  100. f"the serialized type {self._serialized_type} is not supported")
  101. def is_continuous_field(cls: type) -> bool:
  102. if cls in CONTINUOUS_FIELD_TYPE:
  103. return True
  104. for p in cls.__bases__:
  105. if p in CONTINUOUS_FIELD_TYPE:
  106. return True
  107. elif p is not Field and p is not object:
  108. if is_continuous_field(p):
  109. return True
  110. else:
  111. return False
  112. def auto_date_timestamp_field():
  113. return {f"{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  114. def auto_date_timestamp_db_field():
  115. return {f"f_{f}_time" for f in AUTO_DATE_TIMESTAMP_FIELD_PREFIX}
  116. def remove_field_name_prefix(field_name):
  117. return field_name[2:] if field_name.startswith('f_') else field_name
  118. class BaseModel(Model):
  119. create_time = BigIntegerField(null=True, index=True)
  120. create_date = DateTimeField(null=True, index=True)
  121. update_time = BigIntegerField(null=True, index=True)
  122. update_date = DateTimeField(null=True, index=True)
  123. def to_json(self):
  124. # This function is obsolete
  125. return self.to_dict()
  126. def to_dict(self):
  127. return self.__dict__['__data__']
  128. def to_human_model_dict(self, only_primary_with: list | None = None):
  129. model_dict = self.__dict__['__data__']
  130. if not only_primary_with:
  131. return {remove_field_name_prefix(
  132. k): v for k, v in model_dict.items()}
  133. human_model_dict = {}
  134. for k in self._meta.primary_key.field_names:
  135. human_model_dict[remove_field_name_prefix(k)] = model_dict[k]
  136. for k in only_primary_with:
  137. human_model_dict[k] = model_dict[f'f_{k}']
  138. return human_model_dict
  139. @property
  140. def meta(self) -> Metadata:
  141. return self._meta
  142. @classmethod
  143. def get_primary_keys_name(cls):
  144. return cls._meta.primary_key.field_names if isinstance(cls._meta.primary_key, CompositeKey) else [
  145. cls._meta.primary_key.name]
  146. @classmethod
  147. def getter_by(cls, attr):
  148. return operator.attrgetter(attr)(cls)
  149. @classmethod
  150. def query(cls, reverse=None, order_by=None, **kwargs):
  151. filters = []
  152. for f_n, f_v in kwargs.items():
  153. attr_name = '%s' % f_n
  154. if not hasattr(cls, attr_name) or f_v is None:
  155. continue
  156. if type(f_v) in {list, set}:
  157. f_v = list(f_v)
  158. if is_continuous_field(type(getattr(cls, attr_name))):
  159. if len(f_v) == 2:
  160. for i, v in enumerate(f_v):
  161. if isinstance(
  162. v, str) and f_n in auto_date_timestamp_field():
  163. # time type: %Y-%m-%d %H:%M:%S
  164. f_v[i] = utils.date_string_to_timestamp(v)
  165. lt_value = f_v[0]
  166. gt_value = f_v[1]
  167. if lt_value is not None and gt_value is not None:
  168. filters.append(
  169. cls.getter_by(attr_name).between(
  170. lt_value, gt_value))
  171. elif lt_value is not None:
  172. filters.append(
  173. operator.attrgetter(attr_name)(cls) >= lt_value)
  174. elif gt_value is not None:
  175. filters.append(
  176. operator.attrgetter(attr_name)(cls) <= gt_value)
  177. else:
  178. filters.append(operator.attrgetter(attr_name)(cls) << f_v)
  179. else:
  180. filters.append(operator.attrgetter(attr_name)(cls) == f_v)
  181. if filters:
  182. query_records = cls.select().where(*filters)
  183. if reverse is not None:
  184. if not order_by or not hasattr(cls, f"{order_by}"):
  185. order_by = "create_time"
  186. if reverse is True:
  187. query_records = query_records.order_by(
  188. cls.getter_by(f"{order_by}").desc())
  189. elif reverse is False:
  190. query_records = query_records.order_by(
  191. cls.getter_by(f"{order_by}").asc())
  192. return [query_record for query_record in query_records]
  193. else:
  194. return []
  195. @classmethod
  196. def insert(cls, __data=None, **insert):
  197. if isinstance(__data, dict) and __data:
  198. __data[cls._meta.combined["create_time"]
  199. ] = utils.current_timestamp()
  200. if insert:
  201. insert["create_time"] = utils.current_timestamp()
  202. return super().insert(__data, **insert)
  203. # update and insert will call this method
  204. @classmethod
  205. def _normalize_data(cls, data, kwargs):
  206. normalized = super()._normalize_data(data, kwargs)
  207. if not normalized:
  208. return {}
  209. normalized[cls._meta.combined["update_time"]
  210. ] = utils.current_timestamp()
  211. for f_n in AUTO_DATE_TIMESTAMP_FIELD_PREFIX:
  212. if {f"{f_n}_time", f"{f_n}_date"}.issubset(cls._meta.combined.keys()) and \
  213. cls._meta.combined[f"{f_n}_time"] in normalized and \
  214. normalized[cls._meta.combined[f"{f_n}_time"]] is not None:
  215. normalized[cls._meta.combined[f"{f_n}_date"]] = utils.timestamp_to_date(
  216. normalized[cls._meta.combined[f"{f_n}_time"]])
  217. return normalized
  218. class JsonSerializedField(SerializedField):
  219. def __init__(self, object_hook=utils.from_dict_hook,
  220. object_pairs_hook=None, **kwargs):
  221. super(JsonSerializedField, self).__init__(serialized_type=SerializedType.JSON, object_hook=object_hook,
  222. object_pairs_hook=object_pairs_hook, **kwargs)
class PooledDatabase(Enum):
    """Pooled peewee database class, keyed by database backend name."""

    MYSQL = PooledMySQLDatabase
    POSTGRES = PooledPostgresqlDatabase
class DatabaseMigrator(Enum):
    """playhouse.migrate migrator class, keyed by database backend name."""

    MYSQL = MySQLMigrator
    POSTGRES = PostgresqlMigrator
  229. @singleton
  230. class BaseDataBase:
  231. def __init__(self):
  232. database_config = settings.DATABASE.copy()
  233. db_name = database_config.pop("name")
  234. self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(db_name, **database_config)
  235. logging.info('init database on cluster mode successfully')
  236. class PostgresDatabaseLock:
  237. def __init__(self, lock_name, timeout=10, db=None):
  238. self.lock_name = lock_name
  239. self.timeout = int(timeout)
  240. self.db = db if db else DB
  241. def lock(self):
  242. cursor = self.db.execute_sql("SELECT pg_try_advisory_lock(%s)", self.timeout)
  243. ret = cursor.fetchone()
  244. if ret[0] == 0:
  245. raise Exception(f'acquire postgres lock {self.lock_name} timeout')
  246. elif ret[0] == 1:
  247. return True
  248. else:
  249. raise Exception(f'failed to acquire lock {self.lock_name}')
  250. def unlock(self):
  251. cursor = self.db.execute_sql("SELECT pg_advisory_unlock(%s)", self.timeout)
  252. ret = cursor.fetchone()
  253. if ret[0] == 0:
  254. raise Exception(
  255. f'postgres lock {self.lock_name} was not established by this thread')
  256. elif ret[0] == 1:
  257. return True
  258. else:
  259. raise Exception(f'postgres lock {self.lock_name} does not exist')
  260. def __enter__(self):
  261. if isinstance(self.db, PostgresDatabaseLock):
  262. self.lock()
  263. return self
  264. def __exit__(self, exc_type, exc_val, exc_tb):
  265. if isinstance(self.db, PostgresDatabaseLock):
  266. self.unlock()
  267. def __call__(self, func):
  268. @wraps(func)
  269. def magic(*args, **kwargs):
  270. with self:
  271. return func(*args, **kwargs)
  272. return magic
  273. class MysqlDatabaseLock:
  274. def __init__(self, lock_name, timeout=10, db=None):
  275. self.lock_name = lock_name
  276. self.timeout = int(timeout)
  277. self.db = db if db else DB
  278. def lock(self):
  279. # SQL parameters only support %s format placeholders
  280. cursor = self.db.execute_sql(
  281. "SELECT GET_LOCK(%s, %s)", (self.lock_name, self.timeout))
  282. ret = cursor.fetchone()
  283. if ret[0] == 0:
  284. raise Exception(f'acquire mysql lock {self.lock_name} timeout')
  285. elif ret[0] == 1:
  286. return True
  287. else:
  288. raise Exception(f'failed to acquire lock {self.lock_name}')
  289. def unlock(self):
  290. cursor = self.db.execute_sql(
  291. "SELECT RELEASE_LOCK(%s)", (self.lock_name,))
  292. ret = cursor.fetchone()
  293. if ret[0] == 0:
  294. raise Exception(
  295. f'mysql lock {self.lock_name} was not established by this thread')
  296. elif ret[0] == 1:
  297. return True
  298. else:
  299. raise Exception(f'mysql lock {self.lock_name} does not exist')
  300. def __enter__(self):
  301. if isinstance(self.db, PooledMySQLDatabase):
  302. self.lock()
  303. return self
  304. def __exit__(self, exc_type, exc_val, exc_tb):
  305. if isinstance(self.db, PooledMySQLDatabase):
  306. self.unlock()
  307. def __call__(self, func):
  308. @wraps(func)
  309. def magic(*args, **kwargs):
  310. with self:
  311. return func(*args, **kwargs)
  312. return magic
class DatabaseLock(Enum):
    """Lock implementation class, keyed by database backend name."""

    MYSQL = MysqlDatabaseLock
    POSTGRES = PostgresDatabaseLock
# Module-wide pooled connection; created when this module is imported.
DB = BaseDataBase().database_connection
# Attach the lock class matching settings.DATABASE_TYPE as ``DB.lock``.
DB.lock = DatabaseLock[settings.DATABASE_TYPE.upper()].value
  318. def close_connection():
  319. try:
  320. if DB:
  321. DB.close_stale(age=30)
  322. except Exception as e:
  323. logging.exception(e)
class DataBaseModel(BaseModel):
    """BaseModel bound to the module-level ``DB`` connection."""

    class Meta:
        database = DB
  327. @DB.connection_context()
  328. def init_database_tables(alter_fields=[]):
  329. members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
  330. table_objs = []
  331. create_failed_list = []
  332. for name, obj in members:
  333. if obj != DataBaseModel and issubclass(obj, DataBaseModel):
  334. table_objs.append(obj)
  335. logging.debug(f"start create table {obj.__name__}")
  336. try:
  337. obj.create_table()
  338. logging.debug(f"create table success: {obj.__name__}")
  339. except Exception as e:
  340. logging.exception(e)
  341. create_failed_list.append(obj.__name__)
  342. if create_failed_list:
  343. logging.error(f"create tables failed: {create_failed_list}")
  344. raise Exception(f"create tables failed: {create_failed_list}")
  345. migrate_db()
  346. def fill_db_model_object(model_object, human_model_dict):
  347. for k, v in human_model_dict.items():
  348. attr_name = '%s' % k
  349. if hasattr(model_object.__class__, attr_name):
  350. setattr(model_object, attr_name, v)
  351. return model_object
  352. class User(DataBaseModel, UserMixin):
  353. id = CharField(max_length=32, primary_key=True)
  354. access_token = CharField(max_length=255, null=True, index=True)
  355. nickname = CharField(max_length=100, null=False, help_text="nicky name", index=True)
  356. password = CharField(max_length=255, null=True, help_text="password", index=True)
  357. email = CharField(
  358. max_length=255,
  359. null=False,
  360. help_text="email",
  361. index=True)
  362. avatar = TextField(null=True, help_text="avatar base64 string")
  363. language = CharField(
  364. max_length=32,
  365. null=True,
  366. help_text="English|Chinese",
  367. default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
  368. index=True)
  369. color_schema = CharField(
  370. max_length=32,
  371. null=True,
  372. help_text="Bright|Dark",
  373. default="Bright",
  374. index=True)
  375. timezone = CharField(
  376. max_length=64,
  377. null=True,
  378. help_text="Timezone",
  379. default="UTC+8\tAsia/Shanghai",
  380. index=True)
  381. last_login_time = DateTimeField(null=True, index=True)
  382. is_authenticated = CharField(max_length=1, null=False, default="1", index=True)
  383. is_active = CharField(max_length=1, null=False, default="1", index=True)
  384. is_anonymous = CharField(max_length=1, null=False, default="0", index=True)
  385. login_channel = CharField(null=True, help_text="from which user login", index=True)
  386. status = CharField(
  387. max_length=1,
  388. null=True,
  389. help_text="is it validate(0: wasted, 1: validate)",
  390. default="1",
  391. index=True)
  392. is_superuser = BooleanField(null=True, help_text="is root", default=False, index=True)
  393. def __str__(self):
  394. return self.email
  395. def get_id(self):
  396. jwt = Serializer(secret_key=settings.SECRET_KEY)
  397. return jwt.dumps(str(self.access_token))
  398. class Meta:
  399. db_table = "user"
  400. class Tenant(DataBaseModel):
  401. id = CharField(max_length=32, primary_key=True)
  402. name = CharField(max_length=100, null=True, help_text="Tenant name", index=True)
  403. public_key = CharField(max_length=255, null=True, index=True)
  404. llm_id = CharField(max_length=128, null=False, help_text="default llm ID", index=True)
  405. embd_id = CharField(
  406. max_length=128,
  407. null=False,
  408. help_text="default embedding model ID",
  409. index=True)
  410. asr_id = CharField(
  411. max_length=128,
  412. null=False,
  413. help_text="default ASR model ID",
  414. index=True)
  415. img2txt_id = CharField(
  416. max_length=128,
  417. null=False,
  418. help_text="default image to text model ID",
  419. index=True)
  420. rerank_id = CharField(
  421. max_length=128,
  422. null=False,
  423. help_text="default rerank model ID",
  424. index=True)
  425. tts_id = CharField(
  426. max_length=256,
  427. null=True,
  428. help_text="default tts model ID",
  429. index=True)
  430. parser_ids = CharField(
  431. max_length=256,
  432. null=False,
  433. help_text="document processors",
  434. index=True)
  435. credit = IntegerField(default=512, index=True)
  436. status = CharField(
  437. max_length=1,
  438. null=True,
  439. help_text="is it validate(0: wasted, 1: validate)",
  440. default="1",
  441. index=True)
  442. class Meta:
  443. db_table = "tenant"
  444. class UserTenant(DataBaseModel):
  445. id = CharField(max_length=32, primary_key=True)
  446. user_id = CharField(max_length=32, null=False, index=True)
  447. tenant_id = CharField(max_length=32, null=False, index=True)
  448. role = CharField(max_length=32, null=False, help_text="UserTenantRole", index=True)
  449. invited_by = CharField(max_length=32, null=False, index=True)
  450. status = CharField(
  451. max_length=1,
  452. null=True,
  453. help_text="is it validate(0: wasted, 1: validate)",
  454. default="1",
  455. index=True)
  456. class Meta:
  457. db_table = "user_tenant"
  458. class InvitationCode(DataBaseModel):
  459. id = CharField(max_length=32, primary_key=True)
  460. code = CharField(max_length=32, null=False, index=True)
  461. visit_time = DateTimeField(null=True, index=True)
  462. user_id = CharField(max_length=32, null=True, index=True)
  463. tenant_id = CharField(max_length=32, null=True, index=True)
  464. status = CharField(
  465. max_length=1,
  466. null=True,
  467. help_text="is it validate(0: wasted, 1: validate)",
  468. default="1",
  469. index=True)
  470. class Meta:
  471. db_table = "invitation_code"
  472. class LLMFactories(DataBaseModel):
  473. name = CharField(
  474. max_length=128,
  475. null=False,
  476. help_text="LLM factory name",
  477. primary_key=True)
  478. logo = TextField(null=True, help_text="llm logo base64")
  479. tags = CharField(
  480. max_length=255,
  481. null=False,
  482. help_text="LLM, Text Embedding, Image2Text, ASR",
  483. index=True)
  484. status = CharField(
  485. max_length=1,
  486. null=True,
  487. help_text="is it validate(0: wasted, 1: validate)",
  488. default="1",
  489. index=True)
  490. def __str__(self):
  491. return self.name
  492. class Meta:
  493. db_table = "llm_factories"
  494. class LLM(DataBaseModel):
  495. # LLMs dictionary
  496. llm_name = CharField(
  497. max_length=128,
  498. null=False,
  499. help_text="LLM name",
  500. index=True)
  501. model_type = CharField(
  502. max_length=128,
  503. null=False,
  504. help_text="LLM, Text Embedding, Image2Text, ASR",
  505. index=True)
  506. fid = CharField(max_length=128, null=False, help_text="LLM factory id", index=True)
  507. max_tokens = IntegerField(default=0)
  508. tags = CharField(
  509. max_length=255,
  510. null=False,
  511. help_text="LLM, Text Embedding, Image2Text, Chat, 32k...",
  512. index=True)
  513. status = CharField(
  514. max_length=1,
  515. null=True,
  516. help_text="is it validate(0: wasted, 1: validate)",
  517. default="1",
  518. index=True)
  519. def __str__(self):
  520. return self.llm_name
  521. class Meta:
  522. primary_key = CompositeKey('fid', 'llm_name')
  523. db_table = "llm"
  524. class TenantLLM(DataBaseModel):
  525. tenant_id = CharField(max_length=32, null=False, index=True)
  526. llm_factory = CharField(
  527. max_length=128,
  528. null=False,
  529. help_text="LLM factory name",
  530. index=True)
  531. model_type = CharField(
  532. max_length=128,
  533. null=True,
  534. help_text="LLM, Text Embedding, Image2Text, ASR",
  535. index=True)
  536. llm_name = CharField(
  537. max_length=128,
  538. null=True,
  539. help_text="LLM name",
  540. default="",
  541. index=True)
  542. api_key = CharField(max_length=1024, null=True, help_text="API KEY", index=True)
  543. api_base = CharField(max_length=255, null=True, help_text="API Base")
  544. used_tokens = IntegerField(default=0, index=True)
  545. def __str__(self):
  546. return self.llm_name
  547. class Meta:
  548. db_table = "tenant_llm"
  549. primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
  550. class Knowledgebase(DataBaseModel):
  551. id = CharField(max_length=32, primary_key=True)
  552. avatar = TextField(null=True, help_text="avatar base64 string")
  553. tenant_id = CharField(max_length=32, null=False, index=True)
  554. name = CharField(
  555. max_length=128,
  556. null=False,
  557. help_text="KB name",
  558. index=True)
  559. language = CharField(
  560. max_length=32,
  561. null=True,
  562. default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
  563. help_text="English|Chinese",
  564. index=True)
  565. description = TextField(null=True, help_text="KB description")
  566. embd_id = CharField(
  567. max_length=128,
  568. null=False,
  569. help_text="default embedding model ID",
  570. index=True)
  571. permission = CharField(
  572. max_length=16,
  573. null=False,
  574. help_text="me|team",
  575. default="me",
  576. index=True)
  577. created_by = CharField(max_length=32, null=False, index=True)
  578. doc_num = IntegerField(default=0, index=True)
  579. token_num = IntegerField(default=0, index=True)
  580. chunk_num = IntegerField(default=0, index=True)
  581. similarity_threshold = FloatField(default=0.2, index=True)
  582. vector_similarity_weight = FloatField(default=0.3, index=True)
  583. parser_id = CharField(
  584. max_length=32,
  585. null=False,
  586. help_text="default parser ID",
  587. default=ParserType.NAIVE.value,
  588. index=True)
  589. parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
  590. status = CharField(
  591. max_length=1,
  592. null=True,
  593. help_text="is it validate(0: wasted, 1: validate)",
  594. default="1",
  595. index=True)
  596. def __str__(self):
  597. return self.name
  598. class Meta:
  599. db_table = "knowledgebase"
  600. class Document(DataBaseModel):
  601. id = CharField(max_length=32, primary_key=True)
  602. thumbnail = TextField(null=True, help_text="thumbnail base64 string")
  603. kb_id = CharField(max_length=256, null=False, index=True)
  604. parser_id = CharField(
  605. max_length=32,
  606. null=False,
  607. help_text="default parser ID",
  608. index=True)
  609. parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
  610. source_type = CharField(
  611. max_length=128,
  612. null=False,
  613. default="local",
  614. help_text="where dose this document come from",
  615. index=True)
  616. type = CharField(max_length=32, null=False, help_text="file extension",
  617. index=True)
  618. created_by = CharField(
  619. max_length=32,
  620. null=False,
  621. help_text="who created it",
  622. index=True)
  623. name = CharField(
  624. max_length=255,
  625. null=True,
  626. help_text="file name",
  627. index=True)
  628. location = CharField(
  629. max_length=255,
  630. null=True,
  631. help_text="where dose it store",
  632. index=True)
  633. size = IntegerField(default=0, index=True)
  634. token_num = IntegerField(default=0, index=True)
  635. chunk_num = IntegerField(default=0, index=True)
  636. progress = FloatField(default=0, index=True)
  637. progress_msg = TextField(
  638. null=True,
  639. help_text="process message",
  640. default="")
  641. process_begin_at = DateTimeField(null=True, index=True)
  642. process_duation = FloatField(default=0)
  643. run = CharField(
  644. max_length=1,
  645. null=True,
  646. help_text="start to run processing or cancel.(1: run it; 2: cancel)",
  647. default="0",
  648. index=True)
  649. status = CharField(
  650. max_length=1,
  651. null=True,
  652. help_text="is it validate(0: wasted, 1: validate)",
  653. default="1",
  654. index=True)
  655. class Meta:
  656. db_table = "document"
  657. class File(DataBaseModel):
  658. id = CharField(
  659. max_length=32,
  660. primary_key=True)
  661. parent_id = CharField(
  662. max_length=32,
  663. null=False,
  664. help_text="parent folder id",
  665. index=True)
  666. tenant_id = CharField(
  667. max_length=32,
  668. null=False,
  669. help_text="tenant id",
  670. index=True)
  671. created_by = CharField(
  672. max_length=32,
  673. null=False,
  674. help_text="who created it",
  675. index=True)
  676. name = CharField(
  677. max_length=255,
  678. null=False,
  679. help_text="file name or folder name",
  680. index=True)
  681. location = CharField(
  682. max_length=255,
  683. null=True,
  684. help_text="where dose it store",
  685. index=True)
  686. size = IntegerField(default=0, index=True)
  687. type = CharField(max_length=32, null=False, help_text="file extension", index=True)
  688. source_type = CharField(
  689. max_length=128,
  690. null=False,
  691. default="",
  692. help_text="where dose this document come from", index=True)
  693. class Meta:
  694. db_table = "file"
  695. class File2Document(DataBaseModel):
  696. id = CharField(
  697. max_length=32,
  698. primary_key=True)
  699. file_id = CharField(
  700. max_length=32,
  701. null=True,
  702. help_text="file id",
  703. index=True)
  704. document_id = CharField(
  705. max_length=32,
  706. null=True,
  707. help_text="document id",
  708. index=True)
  709. class Meta:
  710. db_table = "file2document"
  711. class Task(DataBaseModel):
  712. id = CharField(max_length=32, primary_key=True)
  713. doc_id = CharField(max_length=32, null=False, index=True)
  714. from_page = IntegerField(default=0)
  715. to_page = IntegerField(default=100000000)
  716. begin_at = DateTimeField(null=True, index=True)
  717. process_duation = FloatField(default=0)
  718. progress = FloatField(default=0, index=True)
  719. progress_msg = TextField(
  720. null=True,
  721. help_text="process message",
  722. default="")
  723. retry_count = IntegerField(default=0)
  724. class Dialog(DataBaseModel):
  725. id = CharField(max_length=32, primary_key=True)
  726. tenant_id = CharField(max_length=32, null=False, index=True)
  727. name = CharField(
  728. max_length=255,
  729. null=True,
  730. help_text="dialog application name",
  731. index=True)
  732. description = TextField(null=True, help_text="Dialog description")
  733. icon = TextField(null=True, help_text="icon base64 string")
  734. language = CharField(
  735. max_length=32,
  736. null=True,
  737. default="Chinese" if "zh_CN" in os.getenv("LANG", "") else "English",
  738. help_text="English|Chinese",
  739. index=True)
  740. llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
  741. llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
  742. "presence_penalty": 0.4, "max_tokens": 512})
  743. prompt_type = CharField(
  744. max_length=16,
  745. null=False,
  746. default="simple",
  747. help_text="simple|advanced",
  748. index=True)
  749. prompt_config = JSONField(null=False, default={"system": "", "prologue": "Hi! I'm your assistant, what can I do for you?",
  750. "parameters": [], "empty_response": "Sorry! No relevant content was found in the knowledge base!"})
  751. similarity_threshold = FloatField(default=0.2)
  752. vector_similarity_weight = FloatField(default=0.3)
  753. top_n = IntegerField(default=6)
  754. top_k = IntegerField(default=1024)
  755. do_refer = CharField(
  756. max_length=1,
  757. null=False,
  758. default="1",
  759. help_text="it needs to insert reference index into answer or not")
  760. rerank_id = CharField(
  761. max_length=128,
  762. null=False,
  763. help_text="default rerank model ID")
  764. kb_ids = JSONField(null=False, default=[])
  765. status = CharField(
  766. max_length=1,
  767. null=True,
  768. help_text="is it validate(0: wasted, 1: validate)",
  769. default="1",
  770. index=True)
  771. class Meta:
  772. db_table = "dialog"
  773. class Conversation(DataBaseModel):
  774. id = CharField(max_length=32, primary_key=True)
  775. dialog_id = CharField(max_length=32, null=False, index=True)
  776. name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
  777. message = JSONField(null=True)
  778. reference = JSONField(null=True, default=[])
  779. class Meta:
  780. db_table = "conversation"
  781. class APIToken(DataBaseModel):
  782. tenant_id = CharField(max_length=32, null=False, index=True)
  783. token = CharField(max_length=255, null=False, index=True)
  784. dialog_id = CharField(max_length=32, null=False, index=True)
  785. source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
  786. class Meta:
  787. db_table = "api_token"
  788. primary_key = CompositeKey('tenant_id', 'token')
  789. class API4Conversation(DataBaseModel):
  790. id = CharField(max_length=32, primary_key=True)
  791. dialog_id = CharField(max_length=32, null=False, index=True)
  792. user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
  793. message = JSONField(null=True)
  794. reference = JSONField(null=True, default=[])
  795. tokens = IntegerField(default=0)
  796. source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
  797. duration = FloatField(default=0, index=True)
  798. round = IntegerField(default=0, index=True)
  799. thumb_up = IntegerField(default=0, index=True)
  800. class Meta:
  801. db_table = "api_4_conversation"
  802. class UserCanvas(DataBaseModel):
  803. id = CharField(max_length=32, primary_key=True)
  804. avatar = TextField(null=True, help_text="avatar base64 string")
  805. user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
  806. title = CharField(max_length=255, null=True, help_text="Canvas title")
  807. description = TextField(null=True, help_text="Canvas description")
  808. canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
  809. dsl = JSONField(null=True, default={})
  810. class Meta:
  811. db_table = "user_canvas"
  812. class CanvasTemplate(DataBaseModel):
  813. id = CharField(max_length=32, primary_key=True)
  814. avatar = TextField(null=True, help_text="avatar base64 string")
  815. title = CharField(max_length=255, null=True, help_text="Canvas title")
  816. description = TextField(null=True, help_text="Canvas description")
  817. canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
  818. dsl = JSONField(null=True, default={})
  819. class Meta:
  820. db_table = "canvas_template"
  821. def migrate_db():
  822. with DB.transaction():
  823. migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
  824. try:
  825. migrate(
  826. migrator.add_column('file', 'source_type', CharField(max_length=128, null=False, default="",
  827. help_text="where dose this document come from",
  828. index=True))
  829. )
  830. except Exception:
  831. pass
  832. try:
  833. migrate(
  834. migrator.add_column('tenant', 'rerank_id',
  835. CharField(max_length=128, null=False, default="BAAI/bge-reranker-v2-m3",
  836. help_text="default rerank model ID"))
  837. )
  838. except Exception:
  839. pass
  840. try:
  841. migrate(
  842. migrator.add_column('dialog', 'rerank_id', CharField(max_length=128, null=False, default="",
  843. help_text="default rerank model ID"))
  844. )
  845. except Exception:
  846. pass
  847. try:
  848. migrate(
  849. migrator.add_column('dialog', 'top_k', IntegerField(default=1024))
  850. )
  851. except Exception:
  852. pass
  853. try:
  854. migrate(
  855. migrator.alter_column_type('tenant_llm', 'api_key',
  856. CharField(max_length=1024, null=True, help_text="API KEY", index=True))
  857. )
  858. except Exception:
  859. pass
  860. try:
  861. migrate(
  862. migrator.add_column('api_token', 'source',
  863. CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
  864. )
  865. except Exception:
  866. pass
  867. try:
  868. migrate(
  869. migrator.add_column("tenant","tts_id",
  870. CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
  871. )
  872. except Exception:
  873. pass
  874. try:
  875. migrate(
  876. migrator.add_column('api_4_conversation', 'source',
  877. CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True))
  878. )
  879. except Exception:
  880. pass
  881. try:
  882. DB.execute_sql('ALTER TABLE llm DROP PRIMARY KEY;')
  883. DB.execute_sql('ALTER TABLE llm ADD PRIMARY KEY (llm_name,fid);')
  884. except Exception:
  885. pass
  886. try:
  887. migrate(
  888. migrator.add_column('task', 'retry_count', IntegerField(default=0))
  889. )
  890. except Exception:
  891. pass
  892. try:
  893. migrate(
  894. migrator.alter_column_type('api_token', 'dialog_id',
  895. CharField(max_length=32, null=True, index=True))
  896. )
  897. except Exception:
  898. pass