You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

provider_configuration.py 79KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871
  1. import json
  2. import logging
  3. import re
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict, Field
  9. from sqlalchemy import func, select
  10. from sqlalchemy.orm import Session
  11. from constants import HIDDEN_VALUE
  12. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  13. from core.entities.provider_entities import (
  14. CustomConfiguration,
  15. ModelSettings,
  16. SystemConfiguration,
  17. SystemConfigurationStatus,
  18. )
  19. from core.helper import encrypter
  20. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  21. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  22. from core.model_runtime.entities.provider_entities import (
  23. ConfigurateMethod,
  24. CredentialFormSchema,
  25. FormType,
  26. ProviderEntity,
  27. )
  28. from core.model_runtime.model_providers.__base.ai_model import AIModel
  29. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  30. from core.plugin.entities.plugin import ModelProviderID
  31. from extensions.ext_database import db
  32. from libs.datetime_utils import naive_utc_now
  33. from models.provider import (
  34. LoadBalancingModelConfig,
  35. Provider,
  36. ProviderCredential,
  37. ProviderModel,
  38. ProviderModelCredential,
  39. ProviderModelSetting,
  40. ProviderType,
  41. TenantPreferredModelProvider,
  42. )
  43. from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)
# Snapshot of each provider's configurate methods as first seen by ProviderConfiguration.__init__,
# keyed by provider name. It lives at module level and is written only when a provider is not yet
# present, so the in-place append of PREDEFINED_MODEL that __init__ may perform on
# ``provider.configurate_methods`` does not change what is considered the provider's original methods.
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
class ProviderConfiguration(BaseModel):
    """
    Provider configuration entity for managing model provider settings.

    This class handles:
    - Provider credentials CRUD and switch
    - Custom Model credentials CRUD and switch
    - System vs custom provider switching
    - Load balancing configurations
    - Model enablement/disablement

    TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
    """

    # Tenant that owns this configuration; every DB query below is scoped to it.
    tenant_id: str
    # Provider entity; note __init__ may append PREDEFINED_MODEL to its configurate_methods in place.
    provider: ProviderEntity
    # Preferred provider type (e.g. SYSTEM or CUSTOM).
    preferred_provider_type: ProviderType
    # Provider type currently in effect; get_current_credentials branches on it.
    using_provider_type: ProviderType
    # System (quota-based) configuration: credentials, quota configurations, current quota type.
    system_configuration: SystemConfiguration
    # Custom configuration: provider-level credentials and per-model configurations.
    custom_configuration: CustomConfiguration
    # Per-model settings; get_current_credentials rejects models whose setting is disabled.
    model_settings: list[ModelSettings]

    # pydantic configs
    # protected_namespaces=() disables pydantic's "model_" namespace protection, which would otherwise
    # warn about the ``model_settings`` field name.
    model_config = ConfigDict(protected_namespaces=())
  66. def __init__(self, **data):
  67. super().__init__(**data)
  68. if self.provider.provider not in original_provider_configurate_methods:
  69. original_provider_configurate_methods[self.provider.provider] = []
  70. for configurate_method in self.provider.configurate_methods:
  71. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  72. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  73. if (
  74. any(
  75. len(quota_configuration.restrict_models) > 0
  76. for quota_configuration in self.system_configuration.quota_configurations
  77. )
  78. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  79. ):
  80. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  81. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  82. """
  83. Get current credentials.
  84. :param model_type: model type
  85. :param model: model name
  86. :return:
  87. """
  88. if self.model_settings:
  89. # check if model is disabled by admin
  90. for model_setting in self.model_settings:
  91. if model_setting.model_type == model_type and model_setting.model == model:
  92. if not model_setting.enabled:
  93. raise ValueError(f"Model {model} is disabled.")
  94. if self.using_provider_type == ProviderType.SYSTEM:
  95. restrict_models = []
  96. for quota_configuration in self.system_configuration.quota_configurations:
  97. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  98. continue
  99. restrict_models = quota_configuration.restrict_models
  100. copy_credentials = (
  101. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  102. )
  103. if restrict_models:
  104. for restrict_model in restrict_models:
  105. if (
  106. restrict_model.model_type == model_type
  107. and restrict_model.model == model
  108. and restrict_model.base_model_name
  109. ):
  110. copy_credentials["base_model_name"] = restrict_model.base_model_name
  111. return copy_credentials
  112. else:
  113. credentials = None
  114. current_credential_id = None
  115. if self.custom_configuration.models:
  116. for model_configuration in self.custom_configuration.models:
  117. if model_configuration.model_type == model_type and model_configuration.model == model:
  118. credentials = model_configuration.credentials
  119. current_credential_id = model_configuration.current_credential_id
  120. break
  121. if not credentials and self.custom_configuration.provider:
  122. credentials = self.custom_configuration.provider.credentials
  123. current_credential_id = self.custom_configuration.provider.current_credential_id
  124. if current_credential_id:
  125. from core.helper.credential_utils import check_credential_policy_compliance
  126. check_credential_policy_compliance(
  127. credential_id=current_credential_id,
  128. provider=self.provider.provider,
  129. credential_type=PluginCredentialType.MODEL,
  130. )
  131. else:
  132. # no current credential id, check all available credentials
  133. if self.custom_configuration.provider:
  134. for credential_configuration in self.custom_configuration.provider.available_credentials:
  135. from core.helper.credential_utils import check_credential_policy_compliance
  136. check_credential_policy_compliance(
  137. credential_id=credential_configuration.credential_id,
  138. provider=self.provider.provider,
  139. credential_type=PluginCredentialType.MODEL,
  140. )
  141. return credentials
  142. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  143. """
  144. Get system configuration status.
  145. :return:
  146. """
  147. if self.system_configuration.enabled is False:
  148. return SystemConfigurationStatus.UNSUPPORTED
  149. current_quota_type = self.system_configuration.current_quota_type
  150. current_quota_configuration = next(
  151. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  152. )
  153. if current_quota_configuration is None:
  154. return None
  155. if not current_quota_configuration:
  156. return SystemConfigurationStatus.UNSUPPORTED
  157. return (
  158. SystemConfigurationStatus.ACTIVE
  159. if current_quota_configuration.is_valid
  160. else SystemConfigurationStatus.QUOTA_EXCEEDED
  161. )
  162. def is_custom_configuration_available(self) -> bool:
  163. """
  164. Check custom configuration available.
  165. :return:
  166. """
  167. has_provider_credentials = (
  168. self.custom_configuration.provider is not None
  169. and len(self.custom_configuration.provider.available_credentials) > 0
  170. )
  171. has_model_configurations = len(self.custom_configuration.models) > 0
  172. return has_provider_credentials or has_model_configurations
  173. def _get_provider_record(self, session: Session) -> Provider | None:
  174. """
  175. Get custom provider record.
  176. """
  177. # get provider
  178. model_provider_id = ModelProviderID(self.provider.provider)
  179. provider_names = [self.provider.provider]
  180. if model_provider_id.is_langgenius():
  181. provider_names.append(model_provider_id.provider_name)
  182. stmt = select(Provider).where(
  183. Provider.tenant_id == self.tenant_id,
  184. Provider.provider_type == ProviderType.CUSTOM.value,
  185. Provider.provider_name.in_(provider_names),
  186. )
  187. return session.execute(stmt).scalar_one_or_none()
  188. def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
  189. """
  190. Get a specific provider credential by ID.
  191. :param credential_id: Credential ID
  192. :return:
  193. """
  194. # Extract secret variables from provider credential schema
  195. credential_secret_variables = self.extract_secret_variables(
  196. self.provider.provider_credential_schema.credential_form_schemas
  197. if self.provider.provider_credential_schema
  198. else []
  199. )
  200. with Session(db.engine) as session:
  201. # Prefer the actual provider record name if exists (to handle aliased provider names)
  202. provider_record = self._get_provider_record(session)
  203. provider_name = provider_record.provider_name if provider_record else self.provider.provider
  204. stmt = select(ProviderCredential).where(
  205. ProviderCredential.id == credential_id,
  206. ProviderCredential.tenant_id == self.tenant_id,
  207. ProviderCredential.provider_name == provider_name,
  208. )
  209. credential = session.execute(stmt).scalar_one_or_none()
  210. if not credential or not credential.encrypted_config:
  211. raise ValueError(f"Credential with id {credential_id} not found.")
  212. try:
  213. credentials = json.loads(credential.encrypted_config)
  214. except JSONDecodeError:
  215. credentials = {}
  216. # Decrypt secret variables
  217. for key in credential_secret_variables:
  218. if key in credentials and credentials[key] is not None:
  219. try:
  220. credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
  221. except Exception:
  222. pass
  223. return self.obfuscated_credentials(
  224. credentials=credentials,
  225. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  226. if self.provider.provider_credential_schema
  227. else [],
  228. )
  229. def _check_provider_credential_name_exists(
  230. self, credential_name: str, session: Session, exclude_id: str | None = None
  231. ) -> bool:
  232. """
  233. not allowed same name when create or update a credential
  234. """
  235. stmt = select(ProviderCredential.id).where(
  236. ProviderCredential.tenant_id == self.tenant_id,
  237. ProviderCredential.provider_name == self.provider.provider,
  238. ProviderCredential.credential_name == credential_name,
  239. )
  240. if exclude_id:
  241. stmt = stmt.where(ProviderCredential.id != exclude_id)
  242. return session.execute(stmt).scalar_one_or_none() is not None
  243. def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
  244. """
  245. Get provider credentials.
  246. :param credential_id: if provided, return the specified credential
  247. :return:
  248. """
  249. if credential_id:
  250. return self._get_specific_provider_credential(credential_id)
  251. # Default behavior: return current active provider credentials
  252. credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
  253. return self.obfuscated_credentials(
  254. credentials=credentials,
  255. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  256. if self.provider.provider_credential_schema
  257. else [],
  258. )
  259. def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
  260. """
  261. Validate custom credentials.
  262. :param credentials: provider credentials
  263. :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
  264. :param session: optional database session
  265. :return:
  266. """
  267. def _validate(s: Session):
  268. # Get provider credential secret variables
  269. provider_credential_secret_variables = self.extract_secret_variables(
  270. self.provider.provider_credential_schema.credential_form_schemas
  271. if self.provider.provider_credential_schema
  272. else []
  273. )
  274. if credential_id:
  275. try:
  276. stmt = select(ProviderCredential).where(
  277. ProviderCredential.tenant_id == self.tenant_id,
  278. ProviderCredential.provider_name == self.provider.provider,
  279. ProviderCredential.id == credential_id,
  280. )
  281. credential_record = s.execute(stmt).scalar_one_or_none()
  282. # fix origin data
  283. if credential_record and credential_record.encrypted_config:
  284. if not credential_record.encrypted_config.startswith("{"):
  285. original_credentials = {"openai_api_key": credential_record.encrypted_config}
  286. else:
  287. original_credentials = json.loads(credential_record.encrypted_config)
  288. else:
  289. original_credentials = {}
  290. except JSONDecodeError:
  291. original_credentials = {}
  292. # encrypt credentials
  293. for key, value in credentials.items():
  294. if key in provider_credential_secret_variables:
  295. # if send [__HIDDEN__] in secret input, it will be same as original value
  296. if value == HIDDEN_VALUE and key in original_credentials:
  297. credentials[key] = encrypter.decrypt_token(
  298. tenant_id=self.tenant_id, token=original_credentials[key]
  299. )
  300. model_provider_factory = ModelProviderFactory(self.tenant_id)
  301. validated_credentials = model_provider_factory.provider_credentials_validate(
  302. provider=self.provider.provider, credentials=credentials
  303. )
  304. for key, value in validated_credentials.items():
  305. if key in provider_credential_secret_variables:
  306. validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  307. return validated_credentials
  308. if session:
  309. return _validate(session)
  310. else:
  311. with Session(db.engine) as new_session:
  312. return _validate(new_session)
  313. def _generate_provider_credential_name(self, session) -> str:
  314. """
  315. Generate a unique credential name for provider.
  316. :return: credential name
  317. """
  318. return self._generate_next_api_key_name(
  319. session=session,
  320. query_factory=lambda: select(ProviderCredential).where(
  321. ProviderCredential.tenant_id == self.tenant_id,
  322. ProviderCredential.provider_name == self.provider.provider,
  323. ),
  324. )
  325. def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
  326. """
  327. Generate a unique credential name for custom model.
  328. :return: credential name
  329. """
  330. return self._generate_next_api_key_name(
  331. session=session,
  332. query_factory=lambda: select(ProviderModelCredential).where(
  333. ProviderModelCredential.tenant_id == self.tenant_id,
  334. ProviderModelCredential.provider_name == self.provider.provider,
  335. ProviderModelCredential.model_name == model,
  336. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  337. ),
  338. )
  339. def _generate_next_api_key_name(self, session, query_factory) -> str:
  340. """
  341. Generate next available API KEY name by finding the highest numbered suffix.
  342. :param session: database session
  343. :param query_factory: function that returns the SQLAlchemy query
  344. :return: next available API KEY name
  345. """
  346. try:
  347. stmt = query_factory()
  348. credential_records = session.execute(stmt).scalars().all()
  349. if not credential_records:
  350. return "API KEY 1"
  351. # Extract numbers from API KEY pattern using list comprehension
  352. pattern = re.compile(r"^API KEY\s+(\d+)$")
  353. numbers = [
  354. int(match.group(1))
  355. for cr in credential_records
  356. if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
  357. ]
  358. # Return next sequential number
  359. next_number = max(numbers, default=0) + 1
  360. return f"API KEY {next_number}"
  361. except Exception as e:
  362. logger.warning("Error generating next credential name: %s", str(e))
  363. return "API KEY 1"
  364. def create_provider_credential(self, credentials: dict, credential_name: str | None):
  365. """
  366. Add custom provider credentials.
  367. :param credentials: provider credentials
  368. :param credential_name: credential name
  369. :return:
  370. """
  371. with Session(db.engine) as session:
  372. if credential_name:
  373. if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
  374. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  375. else:
  376. credential_name = self._generate_provider_credential_name(session)
  377. credentials = self.validate_provider_credentials(credentials=credentials, session=session)
  378. provider_record = self._get_provider_record(session)
  379. try:
  380. new_record = ProviderCredential(
  381. tenant_id=self.tenant_id,
  382. provider_name=self.provider.provider,
  383. encrypted_config=json.dumps(credentials),
  384. credential_name=credential_name,
  385. )
  386. session.add(new_record)
  387. session.flush()
  388. if not provider_record:
  389. # If provider record does not exist, create it
  390. provider_record = Provider(
  391. tenant_id=self.tenant_id,
  392. provider_name=self.provider.provider,
  393. provider_type=ProviderType.CUSTOM.value,
  394. is_valid=True,
  395. credential_id=new_record.id,
  396. )
  397. session.add(provider_record)
  398. provider_model_credentials_cache = ProviderCredentialsCache(
  399. tenant_id=self.tenant_id,
  400. identity_id=provider_record.id,
  401. cache_type=ProviderCredentialsCacheType.PROVIDER,
  402. )
  403. provider_model_credentials_cache.delete()
  404. self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
  405. session.commit()
  406. except Exception:
  407. session.rollback()
  408. raise
  409. def update_provider_credential(
  410. self,
  411. credentials: dict,
  412. credential_id: str,
  413. credential_name: str | None,
  414. ):
  415. """
  416. update a saved provider credential (by credential_id).
  417. :param credentials: provider credentials
  418. :param credential_id: credential id
  419. :param credential_name: credential name
  420. :return:
  421. """
  422. with Session(db.engine) as session:
  423. if credential_name and self._check_provider_credential_name_exists(
  424. credential_name=credential_name, session=session, exclude_id=credential_id
  425. ):
  426. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  427. credentials = self.validate_provider_credentials(
  428. credentials=credentials, credential_id=credential_id, session=session
  429. )
  430. provider_record = self._get_provider_record(session)
  431. stmt = select(ProviderCredential).where(
  432. ProviderCredential.id == credential_id,
  433. ProviderCredential.tenant_id == self.tenant_id,
  434. ProviderCredential.provider_name == self.provider.provider,
  435. )
  436. # Get the credential record to update
  437. credential_record = session.execute(stmt).scalar_one_or_none()
  438. if not credential_record:
  439. raise ValueError("Credential record not found.")
  440. try:
  441. # Update credential
  442. credential_record.encrypted_config = json.dumps(credentials)
  443. credential_record.updated_at = naive_utc_now()
  444. if credential_name:
  445. credential_record.credential_name = credential_name
  446. session.commit()
  447. if provider_record and provider_record.credential_id == credential_id:
  448. provider_model_credentials_cache = ProviderCredentialsCache(
  449. tenant_id=self.tenant_id,
  450. identity_id=provider_record.id,
  451. cache_type=ProviderCredentialsCacheType.PROVIDER,
  452. )
  453. provider_model_credentials_cache.delete()
  454. self._update_load_balancing_configs_with_credential(
  455. credential_id=credential_id,
  456. credential_record=credential_record,
  457. credential_source="provider",
  458. session=session,
  459. )
  460. except Exception:
  461. session.rollback()
  462. raise
  463. def _update_load_balancing_configs_with_credential(
  464. self,
  465. credential_id: str,
  466. credential_record: ProviderCredential | ProviderModelCredential,
  467. credential_source: str,
  468. session: Session,
  469. ):
  470. """
  471. Update load balancing configurations that reference the given credential_id.
  472. :param credential_id: credential id
  473. :param credential_record: the encrypted_config and credential_name
  474. :param credential_source: the credential comes from the provider_credential(`provider`)
  475. or the provider_model_credential(`custom_model`)
  476. :param session: the database session
  477. :return:
  478. """
  479. # Find all load balancing configs that use this credential_id
  480. stmt = select(LoadBalancingModelConfig).where(
  481. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  482. LoadBalancingModelConfig.provider_name == self.provider.provider,
  483. LoadBalancingModelConfig.credential_id == credential_id,
  484. LoadBalancingModelConfig.credential_source_type == credential_source,
  485. )
  486. load_balancing_configs = session.execute(stmt).scalars().all()
  487. if not load_balancing_configs:
  488. return
  489. # Update each load balancing config with the new credentials
  490. for lb_config in load_balancing_configs:
  491. # Update the encrypted_config with the new credentials
  492. lb_config.encrypted_config = credential_record.encrypted_config
  493. lb_config.name = credential_record.credential_name
  494. lb_config.updated_at = naive_utc_now()
  495. # Clear cache for this load balancing config
  496. lb_credentials_cache = ProviderCredentialsCache(
  497. tenant_id=self.tenant_id,
  498. identity_id=lb_config.id,
  499. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  500. )
  501. lb_credentials_cache.delete()
  502. session.commit()
  503. def delete_provider_credential(self, credential_id: str):
  504. """
  505. Delete a saved provider credential (by credential_id).
  506. :param credential_id: credential id
  507. :return:
  508. """
  509. with Session(db.engine) as session:
  510. stmt = select(ProviderCredential).where(
  511. ProviderCredential.id == credential_id,
  512. ProviderCredential.tenant_id == self.tenant_id,
  513. ProviderCredential.provider_name == self.provider.provider,
  514. )
  515. # Get the credential record to update
  516. credential_record = session.execute(stmt).scalar_one_or_none()
  517. if not credential_record:
  518. raise ValueError("Credential record not found.")
  519. # Check if this credential is used in load balancing configs
  520. lb_stmt = select(LoadBalancingModelConfig).where(
  521. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  522. LoadBalancingModelConfig.provider_name == self.provider.provider,
  523. LoadBalancingModelConfig.credential_id == credential_id,
  524. LoadBalancingModelConfig.credential_source_type == "provider",
  525. )
  526. lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
  527. try:
  528. for lb_config in lb_configs_using_credential:
  529. lb_credentials_cache = ProviderCredentialsCache(
  530. tenant_id=self.tenant_id,
  531. identity_id=lb_config.id,
  532. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  533. )
  534. lb_credentials_cache.delete()
  535. session.delete(lb_config)
  536. # Check if this is the currently active credential
  537. provider_record = self._get_provider_record(session)
  538. # Check available credentials count BEFORE deleting
  539. # if this is the last credential, we need to delete the provider record
  540. count_stmt = select(func.count(ProviderCredential.id)).where(
  541. ProviderCredential.tenant_id == self.tenant_id,
  542. ProviderCredential.provider_name == self.provider.provider,
  543. )
  544. available_credentials_count = session.execute(count_stmt).scalar() or 0
  545. session.delete(credential_record)
  546. if provider_record and available_credentials_count <= 1:
  547. # If all credentials are deleted, delete the provider record
  548. session.delete(provider_record)
  549. provider_model_credentials_cache = ProviderCredentialsCache(
  550. tenant_id=self.tenant_id,
  551. identity_id=provider_record.id,
  552. cache_type=ProviderCredentialsCacheType.PROVIDER,
  553. )
  554. provider_model_credentials_cache.delete()
  555. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  556. elif provider_record and provider_record.credential_id == credential_id:
  557. provider_record.credential_id = None
  558. provider_record.updated_at = naive_utc_now()
  559. provider_model_credentials_cache = ProviderCredentialsCache(
  560. tenant_id=self.tenant_id,
  561. identity_id=provider_record.id,
  562. cache_type=ProviderCredentialsCacheType.PROVIDER,
  563. )
  564. provider_model_credentials_cache.delete()
  565. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  566. session.commit()
  567. except Exception:
  568. session.rollback()
  569. raise
  def switch_active_provider_credential(self, credential_id: str):
      """
      Make a saved provider credential the active one for this provider.

      The provider record is pointed at the credential, the cached PROVIDER
      credentials for the record are dropped after the commit, and the
      preferred provider type is switched to CUSTOM.

      :param credential_id: id of a ProviderCredential of this tenant/provider
      :raises ValueError: if the credential or the provider record does not exist
      """
      with Session(db.engine) as session:
          lookup = select(ProviderCredential).where(
              ProviderCredential.id == credential_id,
              ProviderCredential.tenant_id == self.tenant_id,
              ProviderCredential.provider_name == self.provider.provider,
          )
          target_credential = session.execute(lookup).scalar_one_or_none()
          if not target_credential:
              raise ValueError("Credential record not found.")

          provider_record = self._get_provider_record(session)
          if not provider_record:
              raise ValueError("Provider record not found.")

          try:
              provider_record.credential_id = target_credential.id
              provider_record.updated_at = naive_utc_now()
              session.commit()

              # Only drop the cache once the new active credential is persisted.
              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_record.id,
                  cache_type=ProviderCredentialsCacheType.PROVIDER,
              ).delete()

              self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
          except Exception:
              session.rollback()
              raise
  def _get_custom_model_record(
      self,
      model_type: ModelType,
      model: str,
      session: Session,
  ) -> ProviderModel | None:
      """
      Get this tenant's custom model (ProviderModel) record for ``model`` / ``model_type``.

      For langgenius providers both the full provider id and
      ``ModelProviderID(...).provider_name`` are matched, so more than one row can
      qualify. The first match is returned instead of raising ``MultipleResultsFound``,
      the same lookup strategy used by ``_get_provider_model_setting``.

      :param model_type: model type
      :param model: model name
      :param session: session to query with
      :return: the matching record, or None
      """
      model_provider_id = ModelProviderID(self.provider.provider)
      provider_names = [self.provider.provider]
      if model_provider_id.is_langgenius():
          provider_names.append(model_provider_id.provider_name)

      stmt = select(ProviderModel).where(
          ProviderModel.tenant_id == self.tenant_id,
          ProviderModel.provider_name.in_(provider_names),
          ProviderModel.model_name == model,
          ProviderModel.model_type == model_type.to_origin_model_type(),
      )
      return session.execute(stmt).scalars().first()
  def _get_specific_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str
  ) -> dict | None:
      """
      Load one saved custom model credential by id, with its secret fields obfuscated.

      :param model_type: model type
      :param model: model name
      :param credential_id: id of the ProviderModelCredential to load
      :return: dict with ``current_credential_id``, ``current_credential_name`` and the
          obfuscated ``credentials``
      :raises ValueError: if no matching credential with a stored config exists
      """
      form_schemas = (
          self.provider.model_credential_schema.credential_form_schemas
          if self.provider.model_credential_schema
          else []
      )
      secret_keys = self.extract_secret_variables(form_schemas)

      with Session(db.engine) as session:
          lookup = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          record = session.execute(lookup).scalar_one_or_none()
          if not record or not record.encrypted_config:
              raise ValueError(f"Credential with id {credential_id} not found.")

          try:
              decoded = json.loads(record.encrypted_config)
          except JSONDecodeError:
              decoded = {}

          # Best effort: a secret that cannot be decrypted is kept as stored.
          for secret_key in secret_keys:
              if not (secret_key in decoded and decoded[secret_key] is not None):
                  continue
              try:
                  decoded[secret_key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=decoded[secret_key])
              except Exception:
                  pass

          record_id = record.id
          record_name = record.credential_name

          masked = self.obfuscated_credentials(
              credentials=decoded,
              credential_form_schemas=form_schemas,
          )

          return {
              "current_credential_id": record_id,
              "current_credential_name": record_name,
              "credentials": masked,
          }
  def _check_custom_model_credential_name_exists(
      self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
  ) -> bool:
      """
      Tell whether a credential of this model already uses ``credential_name``.

      Credential names must not repeat when a credential is created or updated.

      :param model_type: model type
      :param model: model name
      :param credential_name: name to check
      :param session: session to query with
      :param exclude_id: (Optional) credential id to ignore, e.g. the one being updated
      :return: True if another credential with that name exists
      """
      stmt = select(ProviderModelCredential.id).where(
          ProviderModelCredential.tenant_id == self.tenant_id,
          ProviderModelCredential.provider_name == self.provider.provider,
          ProviderModelCredential.model_name == model,
          ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          ProviderModelCredential.credential_name == credential_name,
      )
      if exclude_id:
          stmt = stmt.where(ProviderModelCredential.id != exclude_id)
      # Pure existence check: LIMIT 1 + first() never raises MultipleResultsFound,
      # even if duplicate names already exist (scalar_one_or_none() would).
      return session.execute(stmt.limit(1)).first() is not None
  def get_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str | None
  ) -> Optional[dict]:
      """
      Get custom model credentials with secret fields obfuscated.

      With ``credential_id`` the given saved credential is loaded. Otherwise the
      credentials of the matching entry in ``self.custom_configuration.models`` are used.

      :param model_type: model type
      :param model: model name
      :param credential_id: (Optional) id of a specific saved credential
      :return: dict with ``current_credential_id``, ``current_credential_name`` and
          ``credentials``, or None when no configured model with credentials matches
      """
      if credential_id:
          return self._get_specific_custom_model_credential(
              model_type=model_type, model=model, credential_id=credential_id
          )

      matching_config = next(
          (
              cfg
              for cfg in self.custom_configuration.models
              if cfg.model_type == model_type and cfg.model == model and cfg.credentials
          ),
          None,
      )
      if matching_config is None:
          return None

      form_schemas = (
          self.provider.model_credential_schema.credential_form_schemas
          if self.provider.model_credential_schema
          else []
      )
      return {
          "current_credential_id": matching_config.current_credential_id,
          "current_credential_name": matching_config.current_credential_name,
          "credentials": self.obfuscated_credentials(
              credentials=matching_config.credentials,
              credential_form_schemas=form_schemas,
          ),
      }
  def validate_custom_model_credentials(
      self,
      model_type: ModelType,
      model: str,
      credentials: dict,
      credential_id: str = "",
      session: Session | None = None,
  ):
      """
      Validate custom model credentials and return them with secret fields encrypted.

      The caller's ``credentials`` dict is left untouched. Hidden secrets are
      restored on a private copy, so decrypted values never leak back to the caller.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict
      :param credential_id: (Optional) If provided, secret fields sent as HIDDEN_VALUE are
          replaced with the stored (decrypted) value of that existing credential
      :param session: (Optional) session to reuse; a new one is opened otherwise
      :return: validated credentials with secret fields encrypted
      """

      def _validate(s: Session):
          # Get provider credential secret variables
          provider_credential_secret_variables = self.extract_secret_variables(
              self.provider.model_credential_schema.credential_form_schemas
              if self.provider.model_credential_schema
              else []
          )

          # Shallow copy: never write decrypted secrets into the caller's dict.
          working_credentials = dict(credentials)

          if credential_id:
              stmt = select(ProviderModelCredential).where(
                  ProviderModelCredential.id == credential_id,
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              credential_record = s.execute(stmt).scalar_one_or_none()
              try:
                  original_credentials = (
                      json.loads(credential_record.encrypted_config)
                      if credential_record and credential_record.encrypted_config
                      else {}
                  )
              except JSONDecodeError:
                  original_credentials = {}

              # if send [__HIDDEN__] in secret input, it will be same as original value
              for key, value in working_credentials.items():
                  if (
                      key in provider_credential_secret_variables
                      and value == HIDDEN_VALUE
                      and key in original_credentials
                  ):
                      working_credentials[key] = encrypter.decrypt_token(
                          tenant_id=self.tenant_id, token=original_credentials[key]
                      )

          model_provider_factory = ModelProviderFactory(self.tenant_id)
          validated_credentials = model_provider_factory.model_credentials_validate(
              provider=self.provider.provider, model_type=model_type, model=model, credentials=working_credentials
          )

          for key, value in validated_credentials.items():
              if key in provider_credential_secret_variables:
                  validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

          return validated_credentials

      if session:
          return _validate(session)
      else:
          with Session(db.engine) as new_session:
              return _validate(new_session)
  def create_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
  ) -> None:
      """
      Validate and store a new custom model credential.

      A name is generated when none is given. A given name must not already be
      used by another credential of the model. If the model has no custom model
      record yet, one is created pointing at the new credential. The model's
      MODEL credentials cache is invalidated after the commit.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict
      :param credential_name: (Optional) display name of the credential
      :raises ValueError: if ``credential_name`` is already taken
      """
      with Session(db.engine) as session:
          if not credential_name:
              credential_name = self._generate_custom_model_credential_name(
                  model=model, model_type=model_type, session=session
              )
          elif self._check_custom_model_credential_name_exists(
              model=model, model_type=model_type, credential_name=credential_name, session=session
          ):
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type, model=model, credentials=credentials, session=session
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          try:
              new_credential = ProviderModelCredential(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  encrypted_config=json.dumps(encrypted_credentials),
                  credential_name=credential_name,
              )
              session.add(new_credential)
              # flush so new_credential.id is available for the model record below
              session.flush()

              if not model_record:
                  model_record = ProviderModel(
                      tenant_id=self.tenant_id,
                      provider_name=self.provider.provider,
                      model_name=model,
                      model_type=model_type.to_origin_model_type(),
                      credential_id=new_credential.id,
                      is_valid=True,
                  )
                  session.add(model_record)

              session.commit()

              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=model_record.id,
                  cache_type=ProviderCredentialsCacheType.MODEL,
              ).delete()
          except Exception:
              session.rollback()
              raise
  def update_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
  ) -> None:
      """
      Update a custom model credential.

      The credential is looked up before anything else. A missing credential is
      therefore reported as "not found" instead of surfacing as a validation
      failure of the submitted (possibly HIDDEN) values.

      :param model_type: model type
      :param model: model name
      :param credentials: model credentials dict
      :param credential_name: (Optional) new credential name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist or the new name is already taken
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          if credential_name and self._check_custom_model_credential_name_exists(
              model=model,
              model_type=model_type,
              credential_name=credential_name,
              session=session,
              exclude_id=credential_id,
          ):
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          # validate custom model config (hidden secrets are restored from credential_id)
          credentials = self.validate_custom_model_credentials(
              model_type=model_type,
              model=model,
              credentials=credentials,
              credential_id=credential_id,
              session=session,
          )
          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          try:
              # Update credential
              credential_record.encrypted_config = json.dumps(credentials)
              credential_record.updated_at = naive_utc_now()
              if credential_name:
                  credential_record.credential_name = credential_name
              session.commit()

              # Only the active credential of the model is held in the MODEL cache.
              if provider_model_record and provider_model_record.credential_id == credential_id:
                  provider_model_credentials_cache = ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=provider_model_record.id,
                      cache_type=ProviderCredentialsCacheType.MODEL,
                  )
                  provider_model_credentials_cache.delete()

              self._update_load_balancing_configs_with_credential(
                  credential_id=credential_id,
                  credential_record=credential_record,
                  credential_source="custom_model",
                  session=session,
              )
          except Exception:
              session.rollback()
              raise
  def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
      """
      Delete a saved custom model credential (by credential_id).

      Load-balancing configs that use the credential are deleted as well.
      - If it was the model's last credential, the custom model record is deleted.
      - Otherwise, if it was the model's active credential, the record is left
        without an active credential.
      In both cases the model's MODEL credentials cache is invalidated after the commit.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist for this tenant/provider/model
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          lb_stmt = select(LoadBalancingModelConfig).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name == self.provider.provider,
              LoadBalancingModelConfig.credential_id == credential_id,
              LoadBalancingModelConfig.credential_source_type == "custom_model",
          )
          lb_configs_using_credential = session.execute(lb_stmt).scalars().all()

          try:
              for lb_config in lb_configs_using_credential:
                  lb_credentials_cache = ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=lb_config.id,
                      cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                  )
                  lb_credentials_cache.delete()
                  session.delete(lb_config)

              provider_model_record = self._get_custom_model_record(model_type, model, session=session)

              # Count BEFORE deleting: if this is the last credential,
              # the custom model record has to go as well.
              count_stmt = select(func.count(ProviderModelCredential.id)).where(
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              available_credentials_count = session.execute(count_stmt).scalar() or 0
              session.delete(credential_record)

              # id of the custom model record whose cached credentials become stale, if any
              stale_model_record_id = None
              if provider_model_record and available_credentials_count <= 1:
                  # If all credentials are deleted, delete the custom model record
                  stale_model_record_id = provider_model_record.id
                  session.delete(provider_model_record)
              elif provider_model_record and provider_model_record.credential_id == credential_id:
                  stale_model_record_id = provider_model_record.id
                  provider_model_record.credential_id = None
                  provider_model_record.updated_at = naive_utc_now()

              session.commit()

              if stale_model_record_id is not None:
                  # Custom model credentials are cached with the MODEL cache type keyed by the
                  # ProviderModel id (as in create/update/delete_custom_model), not PROVIDER.
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=stale_model_record_id,
                      cache_type=ProviderCredentialsCacheType.MODEL,
                  ).delete()
          except Exception:
              session.rollback()
              raise
  def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
      """
      Attach a saved credential to a custom model.

      If the custom model record exists, its credential is switched to ``credential_id``
      and its MODEL credentials cache is invalidated after the commit. Otherwise a new
      custom model record is created using the credential.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential does not exist, or is already the model's credential
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          switched_existing_record = bool(provider_model_record)
          if not provider_model_record:
              # create provider model record
              provider_model_record = ProviderModel(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  is_valid=True,
                  credential_id=credential_id,
              )
          else:
              if provider_model_record.credential_id == credential_record.id:
                  raise ValueError("Can't add same credential")
              provider_model_record.credential_id = credential_record.id
              provider_model_record.updated_at = naive_utc_now()
          session.add(provider_model_record)
          session.commit()

          if switched_existing_record:
              # The previous credential may still be cached for this model record.
              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_model_record.id,
                  cache_type=ProviderCredentialsCacheType.MODEL,
              ).delete()
  def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
      """
      Switch the active credential of an existing custom model.

      The model's MODEL credentials cache is invalidated after the commit, so the
      previously active credential is not served from the cache.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :raises ValueError: if the credential or the custom model record does not exist
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not provider_model_record:
              raise ValueError("The custom model record not found.")

          provider_model_record.credential_id = credential_record.id
          provider_model_record.updated_at = naive_utc_now()
          session.add(provider_model_record)
          session.commit()

          # Same cache key as create/update/delete_custom_model: MODEL type, ProviderModel id.
          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def delete_custom_model(self, model_type: ModelType, model: str):
      """
      Delete the custom model record of ``model`` / ``model_type``, if there is one.

      After the commit, the MODEL credentials cache of the deleted record is cleared.
      Does nothing when no record exists.

      :param model_type: model type
      :param model: model name
      """
      with Session(db.engine) as session:
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not model_record:
              return

          session.delete(model_record)
          session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def _get_provider_model_setting(
      self, model_type: ModelType, model: str, session: Session
  ) -> ProviderModelSetting | None:
      """
      Fetch this tenant's ProviderModelSetting for ``model`` / ``model_type``.

      For langgenius providers, ``ModelProviderID(...).provider_name`` is matched as
      well as the full provider id. The first matching row is returned.

      :param model_type: model type
      :param model: model name
      :param session: session to query with
      :return: the setting, or None
      """
      provider_id = ModelProviderID(self.provider.provider)
      candidate_names = [self.provider.provider]
      if provider_id.is_langgenius():
          candidate_names.append(provider_id.provider_name)

      query = select(ProviderModelSetting).where(
          ProviderModelSetting.tenant_id == self.tenant_id,
          ProviderModelSetting.provider_name.in_(candidate_names),
          ProviderModelSetting.model_type == model_type.to_origin_model_type(),
          ProviderModelSetting.model_name == model,
      )
      matches = session.execute(query).scalars()
      return matches.first()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable model.

      Creates the ProviderModelSetting row when it does not exist yet.

      :param model_type: model type
      :param model: model name
      :return: the persisted ProviderModelSetting (detached, attributes loaded)
      """
      # expire_on_commit=False keeps the committed attributes loaded; with the default,
      # reading the returned object after the session closes raises DetachedInstanceError.
      with Session(db.engine, expire_on_commit=False) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=True,
              )
              session.add(model_setting)
          session.commit()
          return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable model.

      Creates the ProviderModelSetting row when it does not exist yet.

      :param model_type: model type
      :param model: model name
      :return: the persisted ProviderModelSetting (detached, attributes loaded)
      """
      # expire_on_commit=False keeps the committed attributes loaded; with the default,
      # reading the returned object after the session closes raises DetachedInstanceError.
      with Session(db.engine, expire_on_commit=False) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=False,
              )
              session.add(model_setting)
          session.commit()
          return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Look up the ProviderModelSetting of ``model`` / ``model_type`` in a short-lived session.

      :param model_type: model type
      :param model: model name
      :return: the setting, or None when there is none
      """
      with Session(db.engine) as session:
          setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
      return setting
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable model load balancing.

      Requires more than one LoadBalancingModelConfig for the model. Creates the
      ProviderModelSetting row when it does not exist yet.

      :param model_type: model type
      :param model: model name
      :return: the persisted ProviderModelSetting (detached, attributes loaded)
      :raises ValueError: if the model has one or no load-balancing configs
      """
      model_provider_id = ModelProviderID(self.provider.provider)
      provider_names = [self.provider.provider]
      if model_provider_id.is_langgenius():
          provider_names.append(model_provider_id.provider_name)

      # expire_on_commit=False keeps the committed attributes loaded; with the default,
      # reading the returned object after the session closes raises DetachedInstanceError.
      with Session(db.engine, expire_on_commit=False) as session:
          stmt = select(func.count(LoadBalancingModelConfig.id)).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name.in_(provider_names),
              LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
              LoadBalancingModelConfig.model_name == model,
          )
          load_balancing_config_count = session.execute(stmt).scalar() or 0
          if load_balancing_config_count <= 1:
              raise ValueError("Model load balancing configuration must be more than 1.")

          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=True,
              )
              session.add(model_setting)
          session.commit()
          return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable model load balancing.

      Creates the ProviderModelSetting row when it does not exist yet.

      :param model_type: model type
      :param model: model name
      :return: the persisted ProviderModelSetting (detached, attributes loaded)
      """
      # expire_on_commit=False keeps the committed attributes loaded; with the default,
      # reading the returned object after the session closes raises DetachedInstanceError.
      with Session(db.engine, expire_on_commit=False) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=False,
              )
              session.add(model_setting)
          session.commit()
          return model_setting
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Return this provider's model instance for ``model_type``, resolved by ModelProviderFactory.

      :param model_type: model type
      :return: the AIModel instance
      """
      factory = ModelProviderFactory(self.tenant_id)
      instance = factory.get_model_type_instance(
          provider=self.provider.provider,
          model_type=model_type,
      )
      return instance
  def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
      """
      Return the schema of ``model`` for this provider, resolved by ModelProviderFactory.

      :param model_type: model type
      :param model: model name
      :param credentials: credentials forwarded to the factory (may be None)
      :return: the model schema, or None
      """
      factory = ModelProviderFactory(self.tenant_id)
      schema = factory.get_model_schema(
          provider=self.provider.provider,
          model_type=model_type,
          model=model,
          credentials=credentials,
      )
      return schema
  def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
      """
      Persist ``provider_type`` as the tenant's preferred provider type for this provider.

      Nothing is written when it is already the preferred type, or when SYSTEM is
      requested while the system configuration is disabled.

      NOTE: the session that is used is committed, including a caller-supplied one.

      :param provider_type: provider type to prefer
      :param session: (Optional) session to reuse; a new one is opened otherwise
      """
      already_preferred = provider_type == self.preferred_provider_type
      if already_preferred or (provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled):
          return

      def _switch(s: Session):
          provider_id = ModelProviderID(self.provider.provider)
          candidate_names = [self.provider.provider]
          if provider_id.is_langgenius():
              candidate_names.append(provider_id.provider_name)

          lookup = select(TenantPreferredModelProvider).where(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name.in_(candidate_names),
          )
          preference = s.execute(lookup).scalars().first()
          if not preference:
              preference = TenantPreferredModelProvider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  preferred_provider_type=provider_type.value,
              )
              s.add(preference)
          else:
              preference.preferred_provider_type = provider_type.value
          s.commit()

      if session:
          return _switch(session)
      with Session(db.engine) as own_session:
          return _switch(own_session)
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      List the variable names of all secret-input fields, in schema order.

      :param credential_form_schemas: credential form schemas to scan
      :return: variable names whose form type is FormType.SECRET_INPUT
      """
      return [
          schema.variable
          for schema in credential_form_schemas
          if schema.type == FormType.SECRET_INPUT
      ]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
      """
      Return a copy of ``credentials`` in which every secret-input field is obfuscated.

      :param credentials: credentials (not modified)
      :param credential_form_schemas: schemas that determine which fields are secret
      :return: the obfuscated copy
      """
      secret_fields = self.extract_secret_variables(credential_form_schemas)
      masked = credentials.copy()
      masked.update(
          {field: encrypter.obfuscated_token(raw) for field, raw in credentials.items() if field in secret_fields}
      )
      return masked
  def get_provider_model(
      self, model_type: ModelType, model: str, only_active: bool = False
  ) -> Optional[ModelWithProviderEntity]:
      """
      Find a single provider model by name.

      :param model_type: model type
      :param model: model name
      :param only_active: consider active models only
      :return: the first listed model whose name equals ``model``, or None
      """
      candidates = self.get_provider_models(model_type, only_active, model)
      return next((candidate for candidate in candidates if candidate.model == model), None)
  def get_provider_models(
      self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
  ) -> list[ModelWithProviderEntity]:
      """
      List this provider's models, optionally filtered, in display order.

      Models come from the system or the custom configuration, depending on
      ``self.using_provider_type``. The result is sorted by model type value,
      then by the index of the model in ``self.provider.position`` for that type.
      Models that are not listed there go last.

      :param model_type: (Optional) restrict to this model type; all supported types otherwise
      :param only_active: keep only models with status ACTIVE
      :param model: (Optional) model name, forwarded to the custom-model lookup
      :return: sorted list of models
      """
      factory = ModelProviderFactory(self.tenant_id)
      provider_schema = factory.get_provider_schema(self.provider.provider)

      model_types: list[ModelType] = [model_type] if model_type else list(provider_schema.supported_model_types)

      # {model_type: {model_name: settings}}
      model_setting_map: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
      for setting in self.model_settings:
          model_setting_map[setting.model_type][setting.model] = setting

      if self.using_provider_type == ProviderType.SYSTEM:
          provider_models = self._get_system_provider_models(
              model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map
          )
      else:
          provider_models = self._get_custom_provider_models(
              model_types=model_types,
              provider_schema=provider_schema,
              model_setting_map=model_setting_map,
              model=model,
          )

      if only_active:
          provider_models = [entry for entry in provider_models if entry.status == ModelStatus.ACTIVE]

      # Read the per-type position lists once, outside the sort key.
      positions_by_type = getattr(self.provider, "position", None) or {}

      def _sort_key(entity: ModelWithProviderEntity):
          ordered_names = positions_by_type.get(entity.model_type.value, [])
          # Models not in the position list get len(list), which places them after all listed ones.
          rank = ordered_names.index(entity.model) if entity.model in ordered_names else len(ordered_names)
          return entity.model_type.value, rank

      return sorted(provider_models, key=_sort_key)
  1323. def _get_system_provider_models(
  1324. self,
  1325. model_types: Sequence[ModelType],
  1326. provider_schema: ProviderEntity,
  1327. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1328. ) -> list[ModelWithProviderEntity]:
  1329. """
  1330. Get system provider models.
  1331. :param model_types: model types
  1332. :param provider_schema: provider schema
  1333. :param model_setting_map: model setting map
  1334. :return:
  1335. """
  1336. provider_models = []
  1337. for model_type in model_types:
  1338. for m in provider_schema.models:
  1339. if m.model_type != model_type:
  1340. continue
  1341. status = ModelStatus.ACTIVE
  1342. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1343. model_setting = model_setting_map[m.model_type][m.model]
  1344. if model_setting.enabled is False:
  1345. status = ModelStatus.DISABLED
  1346. provider_models.append(
  1347. ModelWithProviderEntity(
  1348. model=m.model,
  1349. label=m.label,
  1350. model_type=m.model_type,
  1351. features=m.features,
  1352. fetch_from=m.fetch_from,
  1353. model_properties=m.model_properties,
  1354. deprecated=m.deprecated,
  1355. provider=SimpleModelProviderEntity(self.provider),
  1356. status=status,
  1357. )
  1358. )
  1359. if self.provider.provider not in original_provider_configurate_methods:
  1360. original_provider_configurate_methods[self.provider.provider] = []
  1361. for configurate_method in provider_schema.configurate_methods:
  1362. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  1363. should_use_custom_model = False
  1364. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  1365. should_use_custom_model = True
  1366. for quota_configuration in self.system_configuration.quota_configurations:
  1367. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  1368. continue
  1369. restrict_models = quota_configuration.restrict_models
  1370. if len(restrict_models) == 0:
  1371. break
  1372. if should_use_custom_model:
  1373. if original_provider_configurate_methods[self.provider.provider] == [
  1374. ConfigurateMethod.CUSTOMIZABLE_MODEL
  1375. ]:
  1376. # only customizable model
  1377. for restrict_model in restrict_models:
  1378. copy_credentials = (
  1379. self.system_configuration.credentials.copy()
  1380. if self.system_configuration.credentials
  1381. else {}
  1382. )
  1383. if restrict_model.base_model_name:
  1384. copy_credentials["base_model_name"] = restrict_model.base_model_name
  1385. try:
  1386. custom_model_schema = self.get_model_schema(
  1387. model_type=restrict_model.model_type,
  1388. model=restrict_model.model,
  1389. credentials=copy_credentials,
  1390. )
  1391. except Exception as ex:
  1392. logger.warning("get custom model schema failed, %s", ex)
  1393. continue
  1394. if not custom_model_schema:
  1395. continue
  1396. if custom_model_schema.model_type not in model_types:
  1397. continue
  1398. status = ModelStatus.ACTIVE
  1399. if (
  1400. custom_model_schema.model_type in model_setting_map
  1401. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1402. ):
  1403. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1404. if model_setting.enabled is False:
  1405. status = ModelStatus.DISABLED
  1406. provider_models.append(
  1407. ModelWithProviderEntity(
  1408. model=custom_model_schema.model,
  1409. label=custom_model_schema.label,
  1410. model_type=custom_model_schema.model_type,
  1411. features=custom_model_schema.features,
  1412. fetch_from=FetchFrom.PREDEFINED_MODEL,
  1413. model_properties=custom_model_schema.model_properties,
  1414. deprecated=custom_model_schema.deprecated,
  1415. provider=SimpleModelProviderEntity(self.provider),
  1416. status=status,
  1417. )
  1418. )
  1419. # if llm name not in restricted llm list, remove it
  1420. restrict_model_names = [rm.model for rm in restrict_models]
  1421. for model in provider_models:
  1422. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  1423. model.status = ModelStatus.NO_PERMISSION
  1424. elif not quota_configuration.is_valid:
  1425. model.status = ModelStatus.QUOTA_EXCEEDED
  1426. return provider_models
  1427. def _get_custom_provider_models(
  1428. self,
  1429. model_types: Sequence[ModelType],
  1430. provider_schema: ProviderEntity,
  1431. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1432. model: Optional[str] = None,
  1433. ) -> list[ModelWithProviderEntity]:
  1434. """
  1435. Get custom provider models.
  1436. :param model_types: model types
  1437. :param provider_schema: provider schema
  1438. :param model_setting_map: model setting map
  1439. :return:
  1440. """
  1441. provider_models = []
  1442. credentials = None
  1443. if self.custom_configuration.provider:
  1444. credentials = self.custom_configuration.provider.credentials
  1445. for model_type in model_types:
  1446. if model_type not in self.provider.supported_model_types:
  1447. continue
  1448. for m in provider_schema.models:
  1449. if m.model_type != model_type:
  1450. continue
  1451. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  1452. load_balancing_enabled = False
  1453. has_invalid_load_balancing_configs = False
  1454. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1455. model_setting = model_setting_map[m.model_type][m.model]
  1456. if model_setting.enabled is False:
  1457. status = ModelStatus.DISABLED
  1458. provider_model_lb_configs = [
  1459. config
  1460. for config in model_setting.load_balancing_configs
  1461. if config.credential_source_type != "custom_model"
  1462. ]
  1463. load_balancing_enabled = model_setting.load_balancing_enabled
  1464. # when the user enable load_balancing but available configs are less than 2 display warning
  1465. has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
  1466. provider_models.append(
  1467. ModelWithProviderEntity(
  1468. model=m.model,
  1469. label=m.label,
  1470. model_type=m.model_type,
  1471. features=m.features,
  1472. fetch_from=m.fetch_from,
  1473. model_properties=m.model_properties,
  1474. deprecated=m.deprecated,
  1475. provider=SimpleModelProviderEntity(self.provider),
  1476. status=status,
  1477. load_balancing_enabled=load_balancing_enabled,
  1478. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1479. )
  1480. )
  1481. # custom models
  1482. for model_configuration in self.custom_configuration.models:
  1483. if model_configuration.model_type not in model_types:
  1484. continue
  1485. if model_configuration.unadded_to_model_list:
  1486. continue
  1487. if model and model != model_configuration.model:
  1488. continue
  1489. try:
  1490. custom_model_schema = self.get_model_schema(
  1491. model_type=model_configuration.model_type,
  1492. model=model_configuration.model,
  1493. credentials=model_configuration.credentials,
  1494. )
  1495. except Exception as ex:
  1496. logger.warning("get custom model schema failed, %s", ex)
  1497. continue
  1498. if not custom_model_schema:
  1499. continue
  1500. status = ModelStatus.ACTIVE
  1501. load_balancing_enabled = False
  1502. has_invalid_load_balancing_configs = False
  1503. if (
  1504. custom_model_schema.model_type in model_setting_map
  1505. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1506. ):
  1507. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1508. if model_setting.enabled is False:
  1509. status = ModelStatus.DISABLED
  1510. custom_model_lb_configs = [
  1511. config
  1512. for config in model_setting.load_balancing_configs
  1513. if config.credential_source_type != "provider"
  1514. ]
  1515. load_balancing_enabled = model_setting.load_balancing_enabled
  1516. # when the user enable load_balancing but available configs are less than 2 display warning
  1517. has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
  1518. if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
  1519. status = ModelStatus.CREDENTIAL_REMOVED
  1520. provider_models.append(
  1521. ModelWithProviderEntity(
  1522. model=custom_model_schema.model,
  1523. label=custom_model_schema.label,
  1524. model_type=custom_model_schema.model_type,
  1525. features=custom_model_schema.features,
  1526. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  1527. model_properties=custom_model_schema.model_properties,
  1528. deprecated=custom_model_schema.deprecated,
  1529. provider=SimpleModelProviderEntity(self.provider),
  1530. status=status,
  1531. load_balancing_enabled=load_balancing_enabled,
  1532. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1533. )
  1534. )
  1535. return provider_models
  1536. class ProviderConfigurations(BaseModel):
  1537. """
  1538. Model class for provider configuration dict.
  1539. """
  1540. tenant_id: str
  1541. configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
  1542. def __init__(self, tenant_id: str):
  1543. super().__init__(tenant_id=tenant_id)
  1544. def get_models(
  1545. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  1546. ) -> list[ModelWithProviderEntity]:
  1547. """
  1548. Get available models.
  1549. If preferred provider type is `system`:
  1550. Get the current **system mode** if provider supported,
  1551. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  1552. If there is no model configured in custom mode, it is treated as no_configure.
  1553. system > custom > no_configure
  1554. If preferred provider type is `custom`:
  1555. If custom credentials are configured, it is treated as custom mode.
  1556. Otherwise, get the current **system mode** if supported,
  1557. If all system modes are not available (no quota), it is treated as no_configure.
  1558. custom > system > no_configure
  1559. If real mode is `system`, use system credentials to get models,
  1560. paid quotas > provider free quotas > system free quotas
  1561. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  1562. If real mode is `custom`, use workspace custom credentials to get models,
  1563. include pre-defined models, custom models(manual append).
  1564. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  1565. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  1566. model status marked as `active` is available.
  1567. :param provider: provider name
  1568. :param model_type: model type
  1569. :param only_active: only active models
  1570. :return:
  1571. """
  1572. all_models = []
  1573. for provider_configuration in self.values():
  1574. if provider and provider_configuration.provider.provider != provider:
  1575. continue
  1576. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  1577. return all_models
  1578. def to_list(self) -> list[ProviderConfiguration]:
  1579. """
  1580. Convert to list.
  1581. :return:
  1582. """
  1583. return list(self.values())
  1584. def __getitem__(self, key):
  1585. if "/" not in key:
  1586. key = str(ModelProviderID(key))
  1587. return self.configurations[key]
  1588. def __setitem__(self, key, value):
  1589. self.configurations[key] = value
  1590. def __contains__(self, key):
  1591. if "/" not in key:
  1592. key = str(ModelProviderID(key))
  1593. return key in self.configurations
  1594. def __iter__(self):
  1595. # Return an iterator of (key, value) tuples to match BaseModel's __iter__
  1596. yield from self.configurations.items()
  1597. def values(self) -> Iterator[ProviderConfiguration]:
  1598. return iter(self.configurations.values())
  1599. def get(self, key, default=None) -> ProviderConfiguration | None:
  1600. if "/" not in key:
  1601. key = str(ModelProviderID(key))
  1602. return self.configurations.get(key, default) # type: ignore
  1603. class ProviderModelBundle(BaseModel):
  1604. """
  1605. Provider model bundle.
  1606. """
  1607. configuration: ProviderConfiguration
  1608. model_type_instance: AIModel
  1609. # pydantic configs
  1610. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())