You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

provider_configuration.py 77KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844
  1. import json
  2. import logging
  3. import re
  4. from collections import defaultdict
  5. from collections.abc import Iterator, Sequence
  6. from json import JSONDecodeError
  7. from typing import Optional
  8. from pydantic import BaseModel, ConfigDict, Field
  9. from sqlalchemy import func, select
  10. from sqlalchemy.orm import Session
  11. from constants import HIDDEN_VALUE
  12. from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
  13. from core.entities.provider_entities import (
  14. CustomConfiguration,
  15. ModelSettings,
  16. SystemConfiguration,
  17. SystemConfigurationStatus,
  18. )
  19. from core.helper import encrypter
  20. from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
  21. from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
  22. from core.model_runtime.entities.provider_entities import (
  23. ConfigurateMethod,
  24. CredentialFormSchema,
  25. FormType,
  26. ProviderEntity,
  27. )
  28. from core.model_runtime.model_providers.__base.ai_model import AIModel
  29. from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
  30. from extensions.ext_database import db
  31. from libs.datetime_utils import naive_utc_now
  32. from models.provider import (
  33. LoadBalancingModelConfig,
  34. Provider,
  35. ProviderCredential,
  36. ProviderModel,
  37. ProviderModelCredential,
  38. ProviderModelSetting,
  39. ProviderType,
  40. TenantPreferredModelProvider,
  41. )
  42. from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)

# Snapshot of each provider's `configurate_methods`, keyed by provider name. It is recorded the
# first time a ProviderConfiguration is constructed for that provider (ProviderConfiguration.__init__),
# i.e. before __init__ may append ConfigurateMethod.PREDEFINED_MODEL to the provider entity.
# Module-level, so it is shared by every ProviderConfiguration instance in the process.
original_provider_configurate_methods: dict[str, list[ConfigurateMethod]] = {}
  45. class ProviderConfiguration(BaseModel):
  46. """
  47. Provider configuration entity for managing model provider settings.
  48. This class handles:
  49. - Provider credentials CRUD and switch
  50. - Custom Model credentials CRUD and switch
  51. - System vs custom provider switching
  52. - Load balancing configurations
  53. - Model enablement/disablement
  54. TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified
  55. """
  56. tenant_id: str
  57. provider: ProviderEntity
  58. preferred_provider_type: ProviderType
  59. using_provider_type: ProviderType
  60. system_configuration: SystemConfiguration
  61. custom_configuration: CustomConfiguration
  62. model_settings: list[ModelSettings]
  63. # pydantic configs
  64. model_config = ConfigDict(protected_namespaces=())
  65. def __init__(self, **data):
  66. super().__init__(**data)
  67. if self.provider.provider not in original_provider_configurate_methods:
  68. original_provider_configurate_methods[self.provider.provider] = []
  69. for configurate_method in self.provider.configurate_methods:
  70. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  71. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  72. if (
  73. any(
  74. len(quota_configuration.restrict_models) > 0
  75. for quota_configuration in self.system_configuration.quota_configurations
  76. )
  77. and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
  78. ):
  79. self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
  80. def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
  81. """
  82. Get current credentials.
  83. :param model_type: model type
  84. :param model: model name
  85. :return:
  86. """
  87. if self.model_settings:
  88. # check if model is disabled by admin
  89. for model_setting in self.model_settings:
  90. if model_setting.model_type == model_type and model_setting.model == model:
  91. if not model_setting.enabled:
  92. raise ValueError(f"Model {model} is disabled.")
  93. if self.using_provider_type == ProviderType.SYSTEM:
  94. restrict_models = []
  95. for quota_configuration in self.system_configuration.quota_configurations:
  96. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  97. continue
  98. restrict_models = quota_configuration.restrict_models
  99. copy_credentials = (
  100. self.system_configuration.credentials.copy() if self.system_configuration.credentials else {}
  101. )
  102. if restrict_models:
  103. for restrict_model in restrict_models:
  104. if (
  105. restrict_model.model_type == model_type
  106. and restrict_model.model == model
  107. and restrict_model.base_model_name
  108. ):
  109. copy_credentials["base_model_name"] = restrict_model.base_model_name
  110. return copy_credentials
  111. else:
  112. credentials = None
  113. if self.custom_configuration.models:
  114. for model_configuration in self.custom_configuration.models:
  115. if model_configuration.model_type == model_type and model_configuration.model == model:
  116. credentials = model_configuration.credentials
  117. break
  118. if not credentials and self.custom_configuration.provider:
  119. credentials = self.custom_configuration.provider.credentials
  120. return credentials
  121. def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
  122. """
  123. Get system configuration status.
  124. :return:
  125. """
  126. if self.system_configuration.enabled is False:
  127. return SystemConfigurationStatus.UNSUPPORTED
  128. current_quota_type = self.system_configuration.current_quota_type
  129. current_quota_configuration = next(
  130. (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
  131. )
  132. if current_quota_configuration is None:
  133. return None
  134. if not current_quota_configuration:
  135. return SystemConfigurationStatus.UNSUPPORTED
  136. return (
  137. SystemConfigurationStatus.ACTIVE
  138. if current_quota_configuration.is_valid
  139. else SystemConfigurationStatus.QUOTA_EXCEEDED
  140. )
  141. def is_custom_configuration_available(self) -> bool:
  142. """
  143. Check custom configuration available.
  144. :return:
  145. """
  146. has_provider_credentials = (
  147. self.custom_configuration.provider is not None
  148. and len(self.custom_configuration.provider.available_credentials) > 0
  149. )
  150. has_model_configurations = len(self.custom_configuration.models) > 0
  151. return has_provider_credentials or has_model_configurations
  152. def _get_provider_record(self, session: Session) -> Provider | None:
  153. """
  154. Get custom provider record.
  155. """
  156. # get provider
  157. model_provider_id = ModelProviderID(self.provider.provider)
  158. provider_names = [self.provider.provider]
  159. if model_provider_id.is_langgenius():
  160. provider_names.append(model_provider_id.provider_name)
  161. stmt = select(Provider).where(
  162. Provider.tenant_id == self.tenant_id,
  163. Provider.provider_type == ProviderType.CUSTOM.value,
  164. Provider.provider_name.in_(provider_names),
  165. )
  166. return session.execute(stmt).scalar_one_or_none()
  167. def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
  168. """
  169. Get a specific provider credential by ID.
  170. :param credential_id: Credential ID
  171. :return:
  172. """
  173. # Extract secret variables from provider credential schema
  174. credential_secret_variables = self.extract_secret_variables(
  175. self.provider.provider_credential_schema.credential_form_schemas
  176. if self.provider.provider_credential_schema
  177. else []
  178. )
  179. with Session(db.engine) as session:
  180. # Prefer the actual provider record name if exists (to handle aliased provider names)
  181. provider_record = self._get_provider_record(session)
  182. provider_name = provider_record.provider_name if provider_record else self.provider.provider
  183. stmt = select(ProviderCredential).where(
  184. ProviderCredential.id == credential_id,
  185. ProviderCredential.tenant_id == self.tenant_id,
  186. ProviderCredential.provider_name == provider_name,
  187. )
  188. credential = session.execute(stmt).scalar_one_or_none()
  189. if not credential or not credential.encrypted_config:
  190. raise ValueError(f"Credential with id {credential_id} not found.")
  191. try:
  192. credentials = json.loads(credential.encrypted_config)
  193. except JSONDecodeError:
  194. credentials = {}
  195. # Decrypt secret variables
  196. for key in credential_secret_variables:
  197. if key in credentials and credentials[key] is not None:
  198. try:
  199. credentials[key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=credentials[key])
  200. except Exception:
  201. pass
  202. return self.obfuscated_credentials(
  203. credentials=credentials,
  204. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  205. if self.provider.provider_credential_schema
  206. else [],
  207. )
  208. def _check_provider_credential_name_exists(
  209. self, credential_name: str, session: Session, exclude_id: str | None = None
  210. ) -> bool:
  211. """
  212. not allowed same name when create or update a credential
  213. """
  214. stmt = select(ProviderCredential.id).where(
  215. ProviderCredential.tenant_id == self.tenant_id,
  216. ProviderCredential.provider_name == self.provider.provider,
  217. ProviderCredential.credential_name == credential_name,
  218. )
  219. if exclude_id:
  220. stmt = stmt.where(ProviderCredential.id != exclude_id)
  221. return session.execute(stmt).scalar_one_or_none() is not None
  222. def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
  223. """
  224. Get provider credentials.
  225. :param credential_id: if provided, return the specified credential
  226. :return:
  227. """
  228. if credential_id:
  229. return self._get_specific_provider_credential(credential_id)
  230. # Default behavior: return current active provider credentials
  231. credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
  232. return self.obfuscated_credentials(
  233. credentials=credentials,
  234. credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
  235. if self.provider.provider_credential_schema
  236. else [],
  237. )
  238. def validate_provider_credentials(
  239. self, credentials: dict, credential_id: str = "", session: Session | None = None
  240. ) -> dict:
  241. """
  242. Validate custom credentials.
  243. :param credentials: provider credentials
  244. :param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
  245. :param session: optional database session
  246. :return:
  247. """
  248. def _validate(s: Session) -> dict:
  249. # Get provider credential secret variables
  250. provider_credential_secret_variables = self.extract_secret_variables(
  251. self.provider.provider_credential_schema.credential_form_schemas
  252. if self.provider.provider_credential_schema
  253. else []
  254. )
  255. if credential_id:
  256. try:
  257. stmt = select(ProviderCredential).where(
  258. ProviderCredential.tenant_id == self.tenant_id,
  259. ProviderCredential.provider_name == self.provider.provider,
  260. ProviderCredential.id == credential_id,
  261. )
  262. credential_record = s.execute(stmt).scalar_one_or_none()
  263. # fix origin data
  264. if credential_record and credential_record.encrypted_config:
  265. if not credential_record.encrypted_config.startswith("{"):
  266. original_credentials = {"openai_api_key": credential_record.encrypted_config}
  267. else:
  268. original_credentials = json.loads(credential_record.encrypted_config)
  269. else:
  270. original_credentials = {}
  271. except JSONDecodeError:
  272. original_credentials = {}
  273. # encrypt credentials
  274. for key, value in credentials.items():
  275. if key in provider_credential_secret_variables:
  276. # if send [__HIDDEN__] in secret input, it will be same as original value
  277. if value == HIDDEN_VALUE and key in original_credentials:
  278. credentials[key] = encrypter.decrypt_token(
  279. tenant_id=self.tenant_id, token=original_credentials[key]
  280. )
  281. model_provider_factory = ModelProviderFactory(self.tenant_id)
  282. validated_credentials = model_provider_factory.provider_credentials_validate(
  283. provider=self.provider.provider, credentials=credentials
  284. )
  285. for key, value in validated_credentials.items():
  286. if key in provider_credential_secret_variables:
  287. validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
  288. return validated_credentials
  289. if session:
  290. return _validate(session)
  291. else:
  292. with Session(db.engine) as new_session:
  293. return _validate(new_session)
  294. def _generate_provider_credential_name(self, session) -> str:
  295. """
  296. Generate a unique credential name for provider.
  297. :return: credential name
  298. """
  299. return self._generate_next_api_key_name(
  300. session=session,
  301. query_factory=lambda: select(ProviderCredential).where(
  302. ProviderCredential.tenant_id == self.tenant_id,
  303. ProviderCredential.provider_name == self.provider.provider,
  304. ),
  305. )
  306. def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
  307. """
  308. Generate a unique credential name for custom model.
  309. :return: credential name
  310. """
  311. return self._generate_next_api_key_name(
  312. session=session,
  313. query_factory=lambda: select(ProviderModelCredential).where(
  314. ProviderModelCredential.tenant_id == self.tenant_id,
  315. ProviderModelCredential.provider_name == self.provider.provider,
  316. ProviderModelCredential.model_name == model,
  317. ProviderModelCredential.model_type == model_type.to_origin_model_type(),
  318. ),
  319. )
  320. def _generate_next_api_key_name(self, session, query_factory) -> str:
  321. """
  322. Generate next available API KEY name by finding the highest numbered suffix.
  323. :param session: database session
  324. :param query_factory: function that returns the SQLAlchemy query
  325. :return: next available API KEY name
  326. """
  327. try:
  328. stmt = query_factory()
  329. credential_records = session.execute(stmt).scalars().all()
  330. if not credential_records:
  331. return "API KEY 1"
  332. # Extract numbers from API KEY pattern using list comprehension
  333. pattern = re.compile(r"^API KEY\s+(\d+)$")
  334. numbers = [
  335. int(match.group(1))
  336. for cr in credential_records
  337. if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
  338. ]
  339. # Return next sequential number
  340. next_number = max(numbers, default=0) + 1
  341. return f"API KEY {next_number}"
  342. except Exception as e:
  343. logger.warning("Error generating next credential name: %s", str(e))
  344. return "API KEY 1"
  345. def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None:
  346. """
  347. Add custom provider credentials.
  348. :param credentials: provider credentials
  349. :param credential_name: credential name
  350. :return:
  351. """
  352. with Session(db.engine) as session:
  353. if credential_name:
  354. if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
  355. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  356. else:
  357. credential_name = self._generate_provider_credential_name(session)
  358. credentials = self.validate_provider_credentials(credentials=credentials, session=session)
  359. provider_record = self._get_provider_record(session)
  360. try:
  361. new_record = ProviderCredential(
  362. tenant_id=self.tenant_id,
  363. provider_name=self.provider.provider,
  364. encrypted_config=json.dumps(credentials),
  365. credential_name=credential_name,
  366. )
  367. session.add(new_record)
  368. session.flush()
  369. if not provider_record:
  370. # If provider record does not exist, create it
  371. provider_record = Provider(
  372. tenant_id=self.tenant_id,
  373. provider_name=self.provider.provider,
  374. provider_type=ProviderType.CUSTOM.value,
  375. is_valid=True,
  376. credential_id=new_record.id,
  377. )
  378. session.add(provider_record)
  379. provider_model_credentials_cache = ProviderCredentialsCache(
  380. tenant_id=self.tenant_id,
  381. identity_id=provider_record.id,
  382. cache_type=ProviderCredentialsCacheType.PROVIDER,
  383. )
  384. provider_model_credentials_cache.delete()
  385. self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
  386. session.commit()
  387. except Exception:
  388. session.rollback()
  389. raise
  390. def update_provider_credential(
  391. self,
  392. credentials: dict,
  393. credential_id: str,
  394. credential_name: str | None,
  395. ) -> None:
  396. """
  397. update a saved provider credential (by credential_id).
  398. :param credentials: provider credentials
  399. :param credential_id: credential id
  400. :param credential_name: credential name
  401. :return:
  402. """
  403. with Session(db.engine) as session:
  404. if credential_name and self._check_provider_credential_name_exists(
  405. credential_name=credential_name, session=session, exclude_id=credential_id
  406. ):
  407. raise ValueError(f"Credential with name '{credential_name}' already exists.")
  408. credentials = self.validate_provider_credentials(
  409. credentials=credentials, credential_id=credential_id, session=session
  410. )
  411. provider_record = self._get_provider_record(session)
  412. stmt = select(ProviderCredential).where(
  413. ProviderCredential.id == credential_id,
  414. ProviderCredential.tenant_id == self.tenant_id,
  415. ProviderCredential.provider_name == self.provider.provider,
  416. )
  417. # Get the credential record to update
  418. credential_record = session.execute(stmt).scalar_one_or_none()
  419. if not credential_record:
  420. raise ValueError("Credential record not found.")
  421. try:
  422. # Update credential
  423. credential_record.encrypted_config = json.dumps(credentials)
  424. credential_record.updated_at = naive_utc_now()
  425. if credential_name:
  426. credential_record.credential_name = credential_name
  427. session.commit()
  428. if provider_record and provider_record.credential_id == credential_id:
  429. provider_model_credentials_cache = ProviderCredentialsCache(
  430. tenant_id=self.tenant_id,
  431. identity_id=provider_record.id,
  432. cache_type=ProviderCredentialsCacheType.PROVIDER,
  433. )
  434. provider_model_credentials_cache.delete()
  435. self._update_load_balancing_configs_with_credential(
  436. credential_id=credential_id,
  437. credential_record=credential_record,
  438. credential_source="provider",
  439. session=session,
  440. )
  441. except Exception:
  442. session.rollback()
  443. raise
  444. def _update_load_balancing_configs_with_credential(
  445. self,
  446. credential_id: str,
  447. credential_record: ProviderCredential | ProviderModelCredential,
  448. credential_source: str,
  449. session: Session,
  450. ) -> None:
  451. """
  452. Update load balancing configurations that reference the given credential_id.
  453. :param credential_id: credential id
  454. :param credential_record: the encrypted_config and credential_name
  455. :param credential_source: the credential comes from the provider_credential(`provider`)
  456. or the provider_model_credential(`custom_model`)
  457. :param session: the database session
  458. :return:
  459. """
  460. # Find all load balancing configs that use this credential_id
  461. stmt = select(LoadBalancingModelConfig).where(
  462. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  463. LoadBalancingModelConfig.provider_name == self.provider.provider,
  464. LoadBalancingModelConfig.credential_id == credential_id,
  465. LoadBalancingModelConfig.credential_source_type == credential_source,
  466. )
  467. load_balancing_configs = session.execute(stmt).scalars().all()
  468. if not load_balancing_configs:
  469. return
  470. # Update each load balancing config with the new credentials
  471. for lb_config in load_balancing_configs:
  472. # Update the encrypted_config with the new credentials
  473. lb_config.encrypted_config = credential_record.encrypted_config
  474. lb_config.name = credential_record.credential_name
  475. lb_config.updated_at = naive_utc_now()
  476. # Clear cache for this load balancing config
  477. lb_credentials_cache = ProviderCredentialsCache(
  478. tenant_id=self.tenant_id,
  479. identity_id=lb_config.id,
  480. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  481. )
  482. lb_credentials_cache.delete()
  483. session.commit()
  484. def delete_provider_credential(self, credential_id: str) -> None:
  485. """
  486. Delete a saved provider credential (by credential_id).
  487. :param credential_id: credential id
  488. :return:
  489. """
  490. with Session(db.engine) as session:
  491. stmt = select(ProviderCredential).where(
  492. ProviderCredential.id == credential_id,
  493. ProviderCredential.tenant_id == self.tenant_id,
  494. ProviderCredential.provider_name == self.provider.provider,
  495. )
  496. # Get the credential record to update
  497. credential_record = session.execute(stmt).scalar_one_or_none()
  498. if not credential_record:
  499. raise ValueError("Credential record not found.")
  500. # Check if this credential is used in load balancing configs
  501. lb_stmt = select(LoadBalancingModelConfig).where(
  502. LoadBalancingModelConfig.tenant_id == self.tenant_id,
  503. LoadBalancingModelConfig.provider_name == self.provider.provider,
  504. LoadBalancingModelConfig.credential_id == credential_id,
  505. LoadBalancingModelConfig.credential_source_type == "provider",
  506. )
  507. lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
  508. try:
  509. for lb_config in lb_configs_using_credential:
  510. lb_credentials_cache = ProviderCredentialsCache(
  511. tenant_id=self.tenant_id,
  512. identity_id=lb_config.id,
  513. cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
  514. )
  515. lb_credentials_cache.delete()
  516. session.delete(lb_config)
  517. # Check if this is the currently active credential
  518. provider_record = self._get_provider_record(session)
  519. # Check available credentials count BEFORE deleting
  520. # if this is the last credential, we need to delete the provider record
  521. count_stmt = select(func.count(ProviderCredential.id)).where(
  522. ProviderCredential.tenant_id == self.tenant_id,
  523. ProviderCredential.provider_name == self.provider.provider,
  524. )
  525. available_credentials_count = session.execute(count_stmt).scalar() or 0
  526. session.delete(credential_record)
  527. if provider_record and available_credentials_count <= 1:
  528. # If all credentials are deleted, delete the provider record
  529. session.delete(provider_record)
  530. provider_model_credentials_cache = ProviderCredentialsCache(
  531. tenant_id=self.tenant_id,
  532. identity_id=provider_record.id,
  533. cache_type=ProviderCredentialsCacheType.PROVIDER,
  534. )
  535. provider_model_credentials_cache.delete()
  536. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  537. elif provider_record and provider_record.credential_id == credential_id:
  538. provider_record.credential_id = None
  539. provider_record.updated_at = naive_utc_now()
  540. provider_model_credentials_cache = ProviderCredentialsCache(
  541. tenant_id=self.tenant_id,
  542. identity_id=provider_record.id,
  543. cache_type=ProviderCredentialsCacheType.PROVIDER,
  544. )
  545. provider_model_credentials_cache.delete()
  546. self.switch_preferred_provider_type(provider_type=ProviderType.SYSTEM, session=session)
  547. session.commit()
  548. except Exception:
  549. session.rollback()
  550. raise
  def switch_active_provider_credential(self, credential_id: str) -> None:
      """
      Make a saved provider credential the active one for this provider.

      The provider record is repointed at the selected credential, the cached
      PROVIDER-level credentials of that record are dropped, and the preferred
      provider type is switched to CUSTOM.

      :param credential_id: id of a ProviderCredential of this tenant/provider
      :return: None
      :raises ValueError: if the credential or the provider record does not exist
      """
      with Session(db.engine) as session:
          credential_query = select(ProviderCredential).where(
              ProviderCredential.id == credential_id,
              ProviderCredential.tenant_id == self.tenant_id,
              ProviderCredential.provider_name == self.provider.provider,
          )
          selected = session.execute(credential_query).scalar_one_or_none()
          if not selected:
              raise ValueError("Credential record not found.")

          provider_record = self._get_provider_record(session)
          if not provider_record:
              raise ValueError("Provider record not found.")

          try:
              provider_record.credential_id = selected.id
              provider_record.updated_at = naive_utc_now()
              session.commit()

              # Invalidate the provider-level credential cache so stale credentials are not served.
              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=provider_record.id,
                  cache_type=ProviderCredentialsCacheType.PROVIDER,
              ).delete()
              self.switch_preferred_provider_type(ProviderType.CUSTOM, session=session)
          except Exception:
              session.rollback()
              raise
  def _get_custom_model_record(
      self,
      model_type: ModelType,
      model: str,
      session: Session,
  ) -> ProviderModel | None:
      """
      Get the custom model record (ProviderModel) for ``model`` / ``model_type``.

      For langgenius plugin providers, records stored under the short (legacy)
      provider name also match.  If records exist under both names, the one
      stored under the full provider name is returned instead of raising
      ``MultipleResultsFound``.

      :param model_type: model type
      :param model: model name
      :param session: database session used for the lookup
      :return: the matching record, or None if there is none
      """
      model_provider_id = ModelProviderID(self.provider.provider)
      provider_names = [self.provider.provider]
      if model_provider_id.is_langgenius():
          provider_names.append(model_provider_id.provider_name)

      stmt = select(ProviderModel).where(
          ProviderModel.tenant_id == self.tenant_id,
          ProviderModel.provider_name.in_(provider_names),
          ProviderModel.model_name == model,
          ProviderModel.model_type == model_type.to_origin_model_type(),
      )
      records = session.execute(stmt).scalars().all()
      # The IN clause can match both the full and the legacy short provider name;
      # prefer the full name rather than letting scalar_one_or_none() raise.
      for record in records:
          if record.provider_name == self.provider.provider:
              return record
      return records[0] if records else None
  def _get_specific_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str
  ) -> dict | None:
      """
      Load one saved custom-model credential by id and return it obfuscated.

      Secret fields are decrypted on a best-effort basis (a field that fails to
      decrypt keeps its stored value) and the result is then obfuscated.

      :param model_type: model type the credential belongs to
      :param model: model name the credential belongs to
      :param credential_id: id of the ProviderModelCredential
      :return: dict with ``current_credential_id``, ``current_credential_name`` and ``credentials``
      :raises ValueError: if no matching credential with a stored config exists
      """
      schema = self.provider.model_credential_schema
      form_schemas = schema.credential_form_schemas if schema else []
      secret_keys = self.extract_secret_variables(form_schemas)

      with Session(db.engine) as session:
          query = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          record = session.execute(query).scalar_one_or_none()
          if not record or not record.encrypted_config:
              raise ValueError(f"Credential with id {credential_id} not found.")

          try:
              decoded = json.loads(record.encrypted_config)
          except JSONDecodeError:
              decoded = {}

          for secret_key in secret_keys:
              if secret_key not in decoded or decoded[secret_key] is None:
                  continue
              try:
                  decoded[secret_key] = encrypter.decrypt_token(tenant_id=self.tenant_id, token=decoded[secret_key])
              except Exception:
                  # Best effort: keep the stored value when it cannot be decrypted.
                  pass

          return {
              "current_credential_id": record.id,
              "current_credential_name": record.credential_name,
              "credentials": self.obfuscated_credentials(credentials=decoded, credential_form_schemas=form_schemas),
          }
  def _check_custom_model_credential_name_exists(
      self, model_type: ModelType, model: str, credential_name: str, session: Session, exclude_id: str | None = None
  ) -> bool:
      """
      Return whether a custom model credential named ``credential_name`` already exists.

      Used to reject duplicate names when creating or renaming a credential.

      :param model_type: model type the credential belongs to
      :param model: model name the credential belongs to
      :param credential_name: name to look for
      :param session: database session used for the lookup
      :param exclude_id: credential id to ignore (e.g. the credential being renamed)
      :return: True if at least one (other) credential uses the name
      """
      stmt = select(ProviderModelCredential.id).where(
          ProviderModelCredential.tenant_id == self.tenant_id,
          ProviderModelCredential.provider_name == self.provider.provider,
          ProviderModelCredential.model_name == model,
          ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          ProviderModelCredential.credential_name == credential_name,
      )
      if exclude_id:
          stmt = stmt.where(ProviderModelCredential.id != exclude_id)
      # An existence check needs one row at most; scalar_one_or_none() would raise
      # MultipleResultsFound if duplicate names were already stored.
      return session.execute(stmt.limit(1)).first() is not None
  def get_custom_model_credential(
      self, model_type: ModelType, model: str, credential_id: str | None
  ) -> Optional[dict]:
      """
      Return the (obfuscated) credentials of a custom model.

      With ``credential_id`` the specific saved credential is loaded; otherwise
      the first matching entry of ``self.custom_configuration.models`` that has
      credentials is used.

      :param model_type: model type
      :param model: model name
      :param credential_id: optional id of a specific saved credential
      :return: dict with ``current_credential_id``, ``current_credential_name``
          and ``credentials``, or None if nothing matches
      """
      if credential_id:
          return self._get_specific_custom_model_credential(
              model_type=model_type, model=model, credential_id=credential_id
          )

      matching = next(
          (
              configuration
              for configuration in self.custom_configuration.models
              if configuration.model_type == model_type
              and configuration.model == model
              and configuration.credentials
          ),
          None,
      )
      if matching is None:
          return None

      schema = self.provider.model_credential_schema
      return {
          "current_credential_id": matching.current_credential_id,
          "current_credential_name": matching.current_credential_name,
          "credentials": self.obfuscated_credentials(
              credentials=matching.credentials,
              credential_form_schemas=schema.credential_form_schemas if schema else [],
          ),
      }
  def validate_custom_model_credentials(
      self,
      model_type: ModelType,
      model: str,
      credentials: dict,
      credential_id: str = "",
      session: Session | None = None,
  ) -> dict:
      """
      Validate custom model credentials and return them with secret fields encrypted.

      When ``credential_id`` is given, secret fields submitted as ``HIDDEN_VALUE``
      are replaced by the decrypted value stored in that existing credential
      before validation.  The substitution is done on a private copy, so the
      caller's ``credentials`` dict is never mutated and never receives
      decrypted plaintext secrets.

      :param model_type: model type
      :param model: model name
      :param credentials: submitted model credentials (not modified)
      :param credential_id: (optional) existing credential whose stored secrets
          back any ``HIDDEN_VALUE`` fields
      :param session: (optional) session to reuse; a new one is opened otherwise
      :return: validated credentials with secret fields encrypted
      """

      def _validate(s: Session) -> dict:
          form_schemas = (
              self.provider.model_credential_schema.credential_form_schemas
              if self.provider.model_credential_schema
              else []
          )
          secret_variables = self.extract_secret_variables(form_schemas)

          # Work on a copy so decrypted secrets never leak into the caller's dict.
          candidate_credentials = dict(credentials)

          if credential_id:
              stmt = select(ProviderModelCredential).where(
                  ProviderModelCredential.id == credential_id,
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              credential_record = s.execute(stmt).scalar_one_or_none()
              try:
                  original_credentials = (
                      json.loads(credential_record.encrypted_config)
                      if credential_record and credential_record.encrypted_config
                      else {}
                  )
              except JSONDecodeError:
                  original_credentials = {}

              for key, value in candidate_credentials.items():
                  # HIDDEN_VALUE in a secret field means "keep the stored value".
                  if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                      candidate_credentials[key] = encrypter.decrypt_token(
                          tenant_id=self.tenant_id, token=original_credentials[key]
                      )

          model_provider_factory = ModelProviderFactory(self.tenant_id)
          validated_credentials = model_provider_factory.model_credentials_validate(
              provider=self.provider.provider, model_type=model_type, model=model, credentials=candidate_credentials
          )

          for key, value in validated_credentials.items():
              if key in secret_variables:
                  validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

          return validated_credentials

      if session:
          return _validate(session)
      with Session(db.engine) as new_session:
          return _validate(new_session)
  def create_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
  ) -> None:
      """
      Validate and save a new credential for a custom model.

      When no custom model record exists yet, one is created that uses the new
      credential.  An empty ``credential_name`` is replaced by a generated one.
      The record's cached MODEL credentials are invalidated after the commit.

      :param model_type: model type
      :param model: model name
      :param credentials: submitted model credentials
      :param credential_name: desired credential name, or None/empty to generate one
      :return: None
      :raises ValueError: if ``credential_name`` is already used for this model
      """
      with Session(db.engine) as session:
          if not credential_name:
              credential_name = self._generate_custom_model_credential_name(
                  model=model, model_type=model_type, session=session
              )
          elif self._check_custom_model_credential_name_exists(
              model=model, model_type=model_type, credential_name=credential_name, session=session
          ):
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type, model=model, credentials=credentials, session=session
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          try:
              new_credential = ProviderModelCredential(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  encrypted_config=json.dumps(encrypted_credentials),
                  credential_name=credential_name,
              )
              session.add(new_credential)
              # Flush before reading new_credential.id below.
              session.flush()

              if not model_record:
                  model_record = ProviderModel(
                      tenant_id=self.tenant_id,
                      provider_name=self.provider.provider,
                      model_name=model,
                      model_type=model_type.to_origin_model_type(),
                      credential_id=new_credential.id,
                      is_valid=True,
                  )
                  session.add(model_record)

              session.commit()

              ProviderCredentialsCache(
                  tenant_id=self.tenant_id,
                  identity_id=model_record.id,
                  cache_type=ProviderCredentialsCacheType.MODEL,
              ).delete()
          except Exception:
              session.rollback()
              raise
  def update_custom_model_credential(
      self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
  ) -> None:
      """
      Validate and overwrite an existing custom model credential.

      If the credential is the custom model's active one, the model's cached
      MODEL credentials are invalidated.  Load-balancing configs referencing
      the credential are then updated via
      ``_update_load_balancing_configs_with_credential``.

      :param model_type: model type
      :param model: model name
      :param credentials: submitted model credentials
      :param credential_name: new credential name (kept unchanged when empty)
      :param credential_id: id of the credential to update
      :return: None
      :raises ValueError: if the new name is taken or the credential does not exist
      """
      with Session(db.engine) as session:
          if credential_name and self._check_custom_model_credential_name_exists(
              model=model,
              model_type=model_type,
              credential_name=credential_name,
              session=session,
              exclude_id=credential_id,
          ):
              raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")

          encrypted_credentials = self.validate_custom_model_credentials(
              model_type=model_type,
              model=model,
              credentials=credentials,
              credential_id=credential_id,
              session=session,
          )
          model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)

          target_query = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          target = session.execute(target_query).scalar_one_or_none()
          if not target:
              raise ValueError("Credential record not found.")

          try:
              target.encrypted_config = json.dumps(encrypted_credentials)
              target.updated_at = naive_utc_now()
              if credential_name:
                  target.credential_name = credential_name
              session.commit()

              if model_record and model_record.credential_id == credential_id:
                  # The active credential of this model changed; drop its cached copy.
                  ProviderCredentialsCache(
                      tenant_id=self.tenant_id,
                      identity_id=model_record.id,
                      cache_type=ProviderCredentialsCacheType.MODEL,
                  ).delete()

              self._update_load_balancing_configs_with_credential(
                  credential_id=credential_id,
                  credential_record=target,
                  credential_source="custom_model",
                  session=session,
              )
          except Exception:
              session.rollback()
              raise
  def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
      """
      Delete a saved custom model credential (by ``credential_id``).

      Load-balancing configs that use the credential are deleted as well.  If
      it is the model's last credential, the custom model record is deleted.
      Otherwise, if it is the model's active credential, the record is
      detached from it.  Affected credential caches are invalidated only after
      the transaction commits.  The custom model record's cache is invalidated
      with ``ProviderCredentialsCacheType.MODEL``, the same cache type the
      create/update/delete paths use for a ProviderModel identity.

      :param model_type: model type
      :param model: model name
      :param credential_id: credential id
      :return: None
      :raises ValueError: if the credential record does not exist
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          lb_stmt = select(LoadBalancingModelConfig).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name == self.provider.provider,
              LoadBalancingModelConfig.credential_id == credential_id,
              LoadBalancingModelConfig.credential_source_type == "custom_model",
          )
          lb_configs_using_credential = session.execute(lb_stmt).scalars().all()

          # Caches to invalidate once the deletion has been committed.
          stale_caches = []
          try:
              for lb_config in lb_configs_using_credential:
                  stale_caches.append(
                      ProviderCredentialsCache(
                          tenant_id=self.tenant_id,
                          identity_id=lb_config.id,
                          cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                      )
                  )
                  session.delete(lb_config)

              provider_model_record = self._get_custom_model_record(model_type, model, session=session)

              # Count BEFORE deleting: if this is the last credential, the custom model record goes too.
              count_stmt = select(func.count(ProviderModelCredential.id)).where(
                  ProviderModelCredential.tenant_id == self.tenant_id,
                  ProviderModelCredential.provider_name == self.provider.provider,
                  ProviderModelCredential.model_name == model,
                  ProviderModelCredential.model_type == model_type.to_origin_model_type(),
              )
              available_credentials_count = session.execute(count_stmt).scalar() or 0
              session.delete(credential_record)

              is_last_credential = available_credentials_count <= 1
              if provider_model_record and (
                  is_last_credential or provider_model_record.credential_id == credential_id
              ):
                  # The model record's cache is keyed by a ProviderModel id, i.e. the MODEL cache type
                  # (PROVIDER would target a different entry and leave this one stale).
                  stale_caches.append(
                      ProviderCredentialsCache(
                          tenant_id=self.tenant_id,
                          identity_id=provider_model_record.id,
                          cache_type=ProviderCredentialsCacheType.MODEL,
                      )
                  )
                  if is_last_credential:
                      session.delete(provider_model_record)
                  else:
                      provider_model_record.credential_id = None
                      provider_model_record.updated_at = naive_utc_now()

              session.commit()
          except Exception:
              session.rollback()
              raise

      for cache in stale_caches:
          cache.delete()
  def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
      """
      Attach a saved credential to a custom model.

      If the custom model record exists, its active credential is switched to
      ``credential_id``.  Otherwise a new custom model record using that
      credential is created.  After the commit, the record's cached MODEL
      credentials are invalidated, as the other credential mutations in this
      class do, so a stale cache entry is not served for the old credential.

      :param model_type: model type
      :param model: model name
      :param credential_id: id of a ProviderModelCredential for this model
      :return: None
      :raises ValueError: if the credential does not exist or is already the active one
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          # Look up the existing custom model record, if any.
          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not provider_model_record:
              provider_model_record = ProviderModel(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_name=model,
                  model_type=model_type.to_origin_model_type(),
                  is_valid=True,
                  credential_id=credential_id,
              )
          else:
              if provider_model_record.credential_id == credential_record.id:
                  raise ValueError("Can't add same credential")
              provider_model_record.credential_id = credential_record.id
              provider_model_record.updated_at = naive_utc_now()

          session.add(provider_model_record)
          session.commit()

          # The active credential changed; drop the cached credentials of this model record.
          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
      """
      Switch the active credential of an existing custom model.

      After the commit, the record's cached MODEL credentials are invalidated,
      mirroring create/update/delete of custom model credentials, so a cache
      entry keyed on this record cannot keep serving the previous credential.

      :param model_type: model type
      :param model: model name
      :param credential_id: id of a ProviderModelCredential for this model
      :return: None
      :raises ValueError: if the credential or the custom model record does not exist
      """
      with Session(db.engine) as session:
          stmt = select(ProviderModelCredential).where(
              ProviderModelCredential.id == credential_id,
              ProviderModelCredential.tenant_id == self.tenant_id,
              ProviderModelCredential.provider_name == self.provider.provider,
              ProviderModelCredential.model_name == model,
              ProviderModelCredential.model_type == model_type.to_origin_model_type(),
          )
          credential_record = session.execute(stmt).scalar_one_or_none()
          if not credential_record:
              raise ValueError("Credential record not found.")

          provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not provider_model_record:
              raise ValueError("The custom model record not found.")

          provider_model_record.credential_id = credential_record.id
          provider_model_record.updated_at = naive_utc_now()
          session.add(provider_model_record)
          session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=provider_model_record.id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def delete_custom_model(self, model_type: ModelType, model: str) -> None:
      """
      Delete the custom model record of ``model`` and drop its cached MODEL credentials.

      Does nothing when no such record exists.

      :param model_type: model type
      :param model: model name
      :return: None
      """
      with Session(db.engine) as session:
          record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
          if not record:
              return

          record_id = record.id
          session.delete(record)
          session.commit()

          ProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=record_id,
              cache_type=ProviderCredentialsCacheType.MODEL,
          ).delete()
  def _get_provider_model_setting(
      self, model_type: ModelType, model: str, session: Session
  ) -> ProviderModelSetting | None:
      """
      Return the first ProviderModelSetting for ``model`` / ``model_type``, if any.

      For langgenius plugin providers, settings stored under the short
      (legacy) provider name match as well.

      :param model_type: model type
      :param model: model name
      :param session: database session used for the lookup
      :return: the setting, or None
      """
      provider_id = ModelProviderID(self.provider.provider)
      candidate_names = (
          [self.provider.provider, provider_id.provider_name]
          if provider_id.is_langgenius()
          else [self.provider.provider]
      )

      query = (
          select(ProviderModelSetting)
          .where(ProviderModelSetting.tenant_id == self.tenant_id)
          .where(ProviderModelSetting.provider_name.in_(candidate_names))
          .where(ProviderModelSetting.model_type == model_type.to_origin_model_type())
          .where(ProviderModelSetting.model_name == model)
      )
      return session.scalars(query).first()
  def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable a model for this tenant, creating its ProviderModelSetting if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved setting, reloaded so it is readable after the session closes
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=True,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires all loaded attributes (expire_on_commit defaults to True) and the
          # session closes when this block exits; reload now so the returned, detached object
          # does not raise DetachedInstanceError on attribute access.
          session.refresh(model_setting)
      return model_setting
  def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable a model for this tenant, creating its ProviderModelSetting if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved setting, reloaded so it is readable after the session closes
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  enabled=False,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires all loaded attributes (expire_on_commit defaults to True) and the
          # session closes when this block exits; reload now so the returned, detached object
          # does not raise DetachedInstanceError on attribute access.
          session.refresh(model_setting)
      return model_setting
  def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
      """
      Look up the ProviderModelSetting of a model using a short-lived session.

      :param model_type: model type
      :param model: model name
      :return: the setting, or None if there is none
      """
      with Session(db.engine) as session:
          setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
      return setting
  def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Enable load balancing for a model.

      Requires more than one LoadBalancingModelConfig for the model (also
      counting configs stored under the legacy short name of langgenius
      providers).

      :param model_type: model type
      :param model: model name
      :return: the saved setting, reloaded so it is readable after the session closes
      :raises ValueError: if there are not more than one load balancing configs
      """
      model_provider_id = ModelProviderID(self.provider.provider)
      provider_names = [self.provider.provider]
      if model_provider_id.is_langgenius():
          provider_names.append(model_provider_id.provider_name)

      with Session(db.engine) as session:
          stmt = select(func.count(LoadBalancingModelConfig.id)).where(
              LoadBalancingModelConfig.tenant_id == self.tenant_id,
              LoadBalancingModelConfig.provider_name.in_(provider_names),
              LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
              LoadBalancingModelConfig.model_name == model,
          )
          load_balancing_config_count = session.execute(stmt).scalar() or 0
          if load_balancing_config_count <= 1:
              raise ValueError("Model load balancing configuration must be more than 1.")

          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = True
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=True,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires all loaded attributes and the session closes when this block exits;
          # reload now so the returned, detached object stays readable.
          session.refresh(model_setting)
      return model_setting
  def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
      """
      Disable load balancing for a model, creating its ProviderModelSetting if needed.

      :param model_type: model type
      :param model: model name
      :return: the saved setting, reloaded so it is readable after the session closes
      """
      with Session(db.engine) as session:
          model_setting = self._get_provider_model_setting(model_type=model_type, model=model, session=session)
          if model_setting:
              model_setting.load_balancing_enabled = False
              model_setting.updated_at = naive_utc_now()
          else:
              model_setting = ProviderModelSetting(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  model_type=model_type.to_origin_model_type(),
                  model_name=model,
                  load_balancing_enabled=False,
              )
              session.add(model_setting)
          session.commit()
          # commit() expires all loaded attributes and the session closes when this block exits;
          # reload now so the returned, detached object stays readable.
          session.refresh(model_setting)
      return model_setting
  def get_model_type_instance(self, model_type: ModelType) -> AIModel:
      """
      Return this provider's model instance for ``model_type``.

      :param model_type: model type
      :return: the AIModel instance produced by ModelProviderFactory
      """
      return ModelProviderFactory(self.tenant_id).get_model_type_instance(
          provider=self.provider.provider, model_type=model_type
      )
  def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
      """
      Return the schema of ``model`` as resolved by ModelProviderFactory.

      :param model_type: model type
      :param model: model name
      :param credentials: credentials passed through to the factory (may be None)
      :return: the model schema, or None
      """
      factory = ModelProviderFactory(self.tenant_id)
      schema = factory.get_model_schema(
          provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
      )
      return schema
  def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
      """
      Persist ``provider_type`` as the tenant's preferred provider type for this provider.

      This is a no-op when ``provider_type`` is already preferred, or when SYSTEM
      is requested while the system configuration is disabled.  The change is
      committed on ``session`` when one is given (note: this commits the
      caller's session), otherwise on a new session.

      :param provider_type: provider type to prefer
      :param session: optional session to use
      :return: None
      """
      if provider_type == self.preferred_provider_type or (
          provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled
      ):
          return

      def _persist(db_session: Session) -> None:
          provider_id = ModelProviderID(self.provider.provider)
          candidate_names = [self.provider.provider]
          if provider_id.is_langgenius():
              candidate_names.append(provider_id.provider_name)

          lookup = select(TenantPreferredModelProvider).where(
              TenantPreferredModelProvider.tenant_id == self.tenant_id,
              TenantPreferredModelProvider.provider_name.in_(candidate_names),
          )
          preference = db_session.execute(lookup).scalars().first()
          if not preference:
              preference = TenantPreferredModelProvider(
                  tenant_id=self.tenant_id,
                  provider_name=self.provider.provider,
                  preferred_provider_type=provider_type.value,
              )
              db_session.add(preference)
          else:
              preference.preferred_provider_type = provider_type.value
          db_session.commit()

      if session:
          _persist(session)
      else:
          with Session(db.engine) as own_session:
              _persist(own_session)
  def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
      """
      Return the variable names of all SECRET_INPUT fields, in schema order.

      :param credential_form_schemas: credential form schemas to scan
      :return: list of secret variable names
      """
      return [schema.variable for schema in credential_form_schemas if schema.type == FormType.SECRET_INPUT]
  def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
      """
      Return a copy of ``credentials`` with every secret field obfuscated.

      :param credentials: credentials (not modified)
      :param credential_form_schemas: form schemas that decide which fields are secret
      :return: new dict with secret values passed through ``encrypter.obfuscated_token``
      """
      secret_keys = self.extract_secret_variables(credential_form_schemas)
      return {
          key: encrypter.obfuscated_token(value) if key in secret_keys else value
          for key, value in credentials.items()
      }
  def get_provider_model(
      self, model_type: ModelType, model: str, only_active: bool = False
  ) -> Optional[ModelWithProviderEntity]:
      """
      Return the provider model named ``model``, if present.

      :param model_type: model type
      :param model: model name
      :param only_active: only consider ACTIVE models
      :return: the first matching model entity, or None
      """
      candidates = self.get_provider_models(model_type, only_active, model)
      return next((entity for entity in candidates if entity.model == model), None)
  def get_provider_models(
      self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
  ) -> list[ModelWithProviderEntity]:
      """
      List this provider's models, sorted by model type and configured position.

      System or custom models are listed depending on ``self.using_provider_type``.
      Models missing from the provider's ``position`` list for their type are
      placed after the listed ones.

      :param model_type: restrict to this model type (all supported types when None)
      :param only_active: keep only models whose status is ACTIVE
      :param model: model name forwarded to the custom-model lookup
      :return: sorted list of model entities
      """
      factory = ModelProviderFactory(self.tenant_id)
      provider_schema = factory.get_provider_schema(self.provider.provider)
      model_types: list[ModelType] = [model_type] if model_type else list(provider_schema.supported_model_types)

      # model type -> model name -> settings
      settings_by_type: defaultdict[ModelType, dict[str, ModelSettings]] = defaultdict(dict)
      for setting in self.model_settings:
          settings_by_type[setting.model_type][setting.model] = setting

      if self.using_provider_type == ProviderType.SYSTEM:
          provider_models = self._get_system_provider_models(
              model_types=model_types, provider_schema=provider_schema, model_setting_map=settings_by_type
          )
      else:
          provider_models = self._get_custom_provider_models(
              model_types=model_types,
              provider_schema=provider_schema,
              model_setting_map=settings_by_type,
              model=model,
          )

      if only_active:
          provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

      # model type value -> ordered list of model names (may be absent on the provider entity)
      positions_by_type = getattr(self.provider, "position", None) or {}

      def _sort_key(entity: ModelWithProviderEntity):
          ordering = positions_by_type.get(entity.model_type.value, [])
          # Unlisted models get len(ordering) so they sort after every listed one.
          rank = ordering.index(entity.model) if entity.model in ordering else len(ordering)
          return (entity.model_type.value, rank)

      return sorted(provider_models, key=_sort_key)
  1304. def _get_system_provider_models(
  1305. self,
  1306. model_types: Sequence[ModelType],
  1307. provider_schema: ProviderEntity,
  1308. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1309. ) -> list[ModelWithProviderEntity]:
  1310. """
  1311. Get system provider models.
  1312. :param model_types: model types
  1313. :param provider_schema: provider schema
  1314. :param model_setting_map: model setting map
  1315. :return:
  1316. """
  1317. provider_models = []
  1318. for model_type in model_types:
  1319. for m in provider_schema.models:
  1320. if m.model_type != model_type:
  1321. continue
  1322. status = ModelStatus.ACTIVE
  1323. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1324. model_setting = model_setting_map[m.model_type][m.model]
  1325. if model_setting.enabled is False:
  1326. status = ModelStatus.DISABLED
  1327. provider_models.append(
  1328. ModelWithProviderEntity(
  1329. model=m.model,
  1330. label=m.label,
  1331. model_type=m.model_type,
  1332. features=m.features,
  1333. fetch_from=m.fetch_from,
  1334. model_properties=m.model_properties,
  1335. deprecated=m.deprecated,
  1336. provider=SimpleModelProviderEntity(self.provider),
  1337. status=status,
  1338. )
  1339. )
  1340. if self.provider.provider not in original_provider_configurate_methods:
  1341. original_provider_configurate_methods[self.provider.provider] = []
  1342. for configurate_method in provider_schema.configurate_methods:
  1343. original_provider_configurate_methods[self.provider.provider].append(configurate_method)
  1344. should_use_custom_model = False
  1345. if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
  1346. should_use_custom_model = True
  1347. for quota_configuration in self.system_configuration.quota_configurations:
  1348. if self.system_configuration.current_quota_type != quota_configuration.quota_type:
  1349. continue
  1350. restrict_models = quota_configuration.restrict_models
  1351. if len(restrict_models) == 0:
  1352. break
  1353. if should_use_custom_model:
  1354. if original_provider_configurate_methods[self.provider.provider] == [
  1355. ConfigurateMethod.CUSTOMIZABLE_MODEL
  1356. ]:
  1357. # only customizable model
  1358. for restrict_model in restrict_models:
  1359. copy_credentials = (
  1360. self.system_configuration.credentials.copy()
  1361. if self.system_configuration.credentials
  1362. else {}
  1363. )
  1364. if restrict_model.base_model_name:
  1365. copy_credentials["base_model_name"] = restrict_model.base_model_name
  1366. try:
  1367. custom_model_schema = self.get_model_schema(
  1368. model_type=restrict_model.model_type,
  1369. model=restrict_model.model,
  1370. credentials=copy_credentials,
  1371. )
  1372. except Exception as ex:
  1373. logger.warning("get custom model schema failed, %s", ex)
  1374. continue
  1375. if not custom_model_schema:
  1376. continue
  1377. if custom_model_schema.model_type not in model_types:
  1378. continue
  1379. status = ModelStatus.ACTIVE
  1380. if (
  1381. custom_model_schema.model_type in model_setting_map
  1382. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1383. ):
  1384. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1385. if model_setting.enabled is False:
  1386. status = ModelStatus.DISABLED
  1387. provider_models.append(
  1388. ModelWithProviderEntity(
  1389. model=custom_model_schema.model,
  1390. label=custom_model_schema.label,
  1391. model_type=custom_model_schema.model_type,
  1392. features=custom_model_schema.features,
  1393. fetch_from=FetchFrom.PREDEFINED_MODEL,
  1394. model_properties=custom_model_schema.model_properties,
  1395. deprecated=custom_model_schema.deprecated,
  1396. provider=SimpleModelProviderEntity(self.provider),
  1397. status=status,
  1398. )
  1399. )
  1400. # if llm name not in restricted llm list, remove it
  1401. restrict_model_names = [rm.model for rm in restrict_models]
  1402. for model in provider_models:
  1403. if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
  1404. model.status = ModelStatus.NO_PERMISSION
  1405. elif not quota_configuration.is_valid:
  1406. model.status = ModelStatus.QUOTA_EXCEEDED
  1407. return provider_models
  1408. def _get_custom_provider_models(
  1409. self,
  1410. model_types: Sequence[ModelType],
  1411. provider_schema: ProviderEntity,
  1412. model_setting_map: dict[ModelType, dict[str, ModelSettings]],
  1413. model: Optional[str] = None,
  1414. ) -> list[ModelWithProviderEntity]:
  1415. """
  1416. Get custom provider models.
  1417. :param model_types: model types
  1418. :param provider_schema: provider schema
  1419. :param model_setting_map: model setting map
  1420. :return:
  1421. """
  1422. provider_models = []
  1423. credentials = None
  1424. if self.custom_configuration.provider:
  1425. credentials = self.custom_configuration.provider.credentials
  1426. for model_type in model_types:
  1427. if model_type not in self.provider.supported_model_types:
  1428. continue
  1429. for m in provider_schema.models:
  1430. if m.model_type != model_type:
  1431. continue
  1432. status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
  1433. load_balancing_enabled = False
  1434. has_invalid_load_balancing_configs = False
  1435. if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
  1436. model_setting = model_setting_map[m.model_type][m.model]
  1437. if model_setting.enabled is False:
  1438. status = ModelStatus.DISABLED
  1439. provider_model_lb_configs = [
  1440. config
  1441. for config in model_setting.load_balancing_configs
  1442. if config.credential_source_type != "custom_model"
  1443. ]
  1444. load_balancing_enabled = model_setting.load_balancing_enabled
  1445. # when the user enable load_balancing but available configs are less than 2 display warning
  1446. has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
  1447. provider_models.append(
  1448. ModelWithProviderEntity(
  1449. model=m.model,
  1450. label=m.label,
  1451. model_type=m.model_type,
  1452. features=m.features,
  1453. fetch_from=m.fetch_from,
  1454. model_properties=m.model_properties,
  1455. deprecated=m.deprecated,
  1456. provider=SimpleModelProviderEntity(self.provider),
  1457. status=status,
  1458. load_balancing_enabled=load_balancing_enabled,
  1459. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1460. )
  1461. )
  1462. # custom models
  1463. for model_configuration in self.custom_configuration.models:
  1464. if model_configuration.model_type not in model_types:
  1465. continue
  1466. if model_configuration.unadded_to_model_list:
  1467. continue
  1468. if model and model != model_configuration.model:
  1469. continue
  1470. try:
  1471. custom_model_schema = self.get_model_schema(
  1472. model_type=model_configuration.model_type,
  1473. model=model_configuration.model,
  1474. credentials=model_configuration.credentials,
  1475. )
  1476. except Exception as ex:
  1477. logger.warning("get custom model schema failed, %s", ex)
  1478. continue
  1479. if not custom_model_schema:
  1480. continue
  1481. status = ModelStatus.ACTIVE
  1482. load_balancing_enabled = False
  1483. has_invalid_load_balancing_configs = False
  1484. if (
  1485. custom_model_schema.model_type in model_setting_map
  1486. and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
  1487. ):
  1488. model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
  1489. if model_setting.enabled is False:
  1490. status = ModelStatus.DISABLED
  1491. custom_model_lb_configs = [
  1492. config
  1493. for config in model_setting.load_balancing_configs
  1494. if config.credential_source_type != "provider"
  1495. ]
  1496. load_balancing_enabled = model_setting.load_balancing_enabled
  1497. # when the user enable load_balancing but available configs are less than 2 display warning
  1498. has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
  1499. if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
  1500. status = ModelStatus.CREDENTIAL_REMOVED
  1501. provider_models.append(
  1502. ModelWithProviderEntity(
  1503. model=custom_model_schema.model,
  1504. label=custom_model_schema.label,
  1505. model_type=custom_model_schema.model_type,
  1506. features=custom_model_schema.features,
  1507. fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
  1508. model_properties=custom_model_schema.model_properties,
  1509. deprecated=custom_model_schema.deprecated,
  1510. provider=SimpleModelProviderEntity(self.provider),
  1511. status=status,
  1512. load_balancing_enabled=load_balancing_enabled,
  1513. has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
  1514. )
  1515. )
  1516. return provider_models
  1517. class ProviderConfigurations(BaseModel):
  1518. """
  1519. Model class for provider configuration dict.
  1520. """
  1521. tenant_id: str
  1522. configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
  1523. def __init__(self, tenant_id: str):
  1524. super().__init__(tenant_id=tenant_id)
  1525. def get_models(
  1526. self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
  1527. ) -> list[ModelWithProviderEntity]:
  1528. """
  1529. Get available models.
  1530. If preferred provider type is `system`:
  1531. Get the current **system mode** if provider supported,
  1532. if all system modes are not available (no quota), it is considered to be the **custom credential mode**.
  1533. If there is no model configured in custom mode, it is treated as no_configure.
  1534. system > custom > no_configure
  1535. If preferred provider type is `custom`:
  1536. If custom credentials are configured, it is treated as custom mode.
  1537. Otherwise, get the current **system mode** if supported,
  1538. If all system modes are not available (no quota), it is treated as no_configure.
  1539. custom > system > no_configure
  1540. If real mode is `system`, use system credentials to get models,
  1541. paid quotas > provider free quotas > system free quotas
  1542. include pre-defined models (exclude GPT-4, status marked as `no_permission`).
  1543. If real mode is `custom`, use workspace custom credentials to get models,
  1544. include pre-defined models, custom models(manual append).
  1545. If real mode is `no_configure`, only return pre-defined models from `model runtime`.
  1546. (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`)
  1547. model status marked as `active` is available.
  1548. :param provider: provider name
  1549. :param model_type: model type
  1550. :param only_active: only active models
  1551. :return:
  1552. """
  1553. all_models = []
  1554. for provider_configuration in self.values():
  1555. if provider and provider_configuration.provider.provider != provider:
  1556. continue
  1557. all_models.extend(provider_configuration.get_provider_models(model_type, only_active))
  1558. return all_models
  1559. def to_list(self) -> list[ProviderConfiguration]:
  1560. """
  1561. Convert to list.
  1562. :return:
  1563. """
  1564. return list(self.values())
  1565. def __getitem__(self, key):
  1566. if "/" not in key:
  1567. key = str(ModelProviderID(key))
  1568. return self.configurations[key]
  1569. def __setitem__(self, key, value):
  1570. self.configurations[key] = value
  1571. def __iter__(self):
  1572. return iter(self.configurations)
  1573. def values(self) -> Iterator[ProviderConfiguration]:
  1574. return iter(self.configurations.values())
  1575. def get(self, key, default=None) -> ProviderConfiguration | None:
  1576. if "/" not in key:
  1577. key = str(ModelProviderID(key))
  1578. return self.configurations.get(key, default) # type: ignore
  1579. class ProviderModelBundle(BaseModel):
  1580. """
  1581. Provider model bundle.
  1582. """
  1583. configuration: ProviderConfiguration
  1584. model_type_instance: AIModel
  1585. # pydantic configs
  1586. model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())