Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.


  1. import logging
  2. from flask_login import current_user
  3. from flask_restx import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.wraps import account_initialization_required, setup_required
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from libs.helper import StrLen, uuid_value
  11. from libs.login import login_required
  12. from services.model_load_balancing_service import ModelLoadBalancingService
  13. from services.model_provider_service import ModelProviderService
  14. logger = logging.getLogger(__name__)
  class DefaultModelApi(Resource):
      """Read and update the current workspace's default model per model type."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return the default model for the requested model type.

          Query args:
              model_type: required; must be one of the ``ModelType`` values.

          Returns:
              ``{"data": <default model entity>}`` passed through ``jsonable_encoder``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()
          default_model_entity = model_provider_service.get_default_model_of_model_type(
              tenant_id=tenant_id, model_type=args["model_type"]
          )
          return jsonable_encoder({"data": default_model_entity})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self):
          """Update the default model for each entry in ``model_settings``.

          JSON body:
              model_settings: required list of objects carrying ``model_type`` and,
                  optionally, ``provider`` and ``model``. Entries without a
                  ``provider`` key are skipped.

          Entries are validated and applied one at a time, in order, so entries
          that precede a failing one have already been applied.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: an entry is not an object, has a missing or unknown
                  ``model_type``, or has a ``provider`` but no ``model``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()

          # Loop-invariant, so computed once. Kept as a tuple (not a set) so an
          # unhashable "model_type" value is still rejected with ValueError below.
          valid_model_types = tuple(mt.value for mt in ModelType)

          for model_setting in args["model_settings"]:
              # `type=list` does not constrain element types; a non-object entry
              # would otherwise fail with TypeError on the key lookups below.
              if not isinstance(model_setting, dict):
                  raise ValueError("invalid model type")
              if model_setting.get("model_type") not in valid_model_types:
                  raise ValueError("invalid model type")

              if "provider" not in model_setting:
                  continue

              if "model" not in model_setting:
                  raise ValueError("invalid model")

              try:
                  model_provider_service.update_default_model_of_model_type(
                      tenant_id=tenant_id,
                      model_type=model_setting["model_type"],
                      provider=model_setting["provider"],
                      model=model_setting["model"],
                  )
              except Exception:
                  logger.exception(
                      "Failed to update default model, model type: %s, model: %s",
                      model_setting["model_type"],
                      model_setting.get("model"),
                  )
                  # Bare raise re-raises the active exception with its traceback intact.
                  raise

          return {"result": "success"}
  class ModelProviderModelApi(Resource):
      """List, configure and remove the models of one provider in the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider):
          """Return ``{"data": <models of the provider>}`` for the current workspace."""
          tenant_id = current_user.current_tenant_id

          model_provider_service = ModelProviderService()
          models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)

          return jsonable_encoder({"data": models})

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Save a model's load-balancing configuration.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              load_balancing: optional object; ``configs`` is saved when present
                  and ``enabled`` decides whether load balancing is switched on.
              config_from: optional; ``"custom-model"`` additionally switches the
                  active custom-model credential to ``credential_id``.
              credential_id: optional UUID; required when ``config_from`` is
                  ``"custom-model"``.

          Returns:
              ``({"result": "success"}, 200)``.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: ``config_from`` is ``"custom-model"`` without a ``credential_id``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
          parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
          parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json")
          args = parser.parse_args()

          # reqparse stores absent (and explicit null) optional arguments as None,
          # so `args.get(key, default)` never falls back to the default. Normalise
          # here; otherwise a request without "load_balancing" fails on `None.get`.
          config_from = args.get("config_from") or ""
          load_balancing = args.get("load_balancing") or {}

          if config_from == "custom-model":
              if not args.get("credential_id"):
                  raise ValueError("credential_id is required when configuring a custom-model")
              service = ModelProviderService()
              service.switch_active_custom_model_credential(
                  tenant_id=tenant_id,
                  provider=provider,
                  model_type=args["model_type"],
                  model=args["model"],
                  credential_id=args["credential_id"],
              )

          model_load_balancing_service = ModelLoadBalancingService()

          if "configs" in load_balancing:
              # save load balancing configs
              model_load_balancing_service.update_load_balancing_configs(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=args["model"],
                  model_type=args["model_type"],
                  configs=load_balancing["configs"],
                  config_from=config_from,
              )

          if load_balancing.get("enabled"):
              model_load_balancing_service.enable_model_load_balancing(
                  tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
              )
          else:
              # A missing or falsy "enabled" (including no "load_balancing" at all) disables it.
              model_load_balancing_service.disable_model_load_balancing(
                  tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
              )

          return {"result": "success"}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Remove a model from the provider.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.

          Returns:
              ``({"result": "success"}, 204)``.

          Raises:
              Forbidden: the current user is not an admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          args = parser.parse_args()

          model_provider_service = ModelProviderService()
          model_provider_service.remove_model(
              tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
          )

          return {"result": "success"}, 204
  class ModelProviderModelCredentialApi(Resource):
      """Read, create, update and delete the credentials of a provider's model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return a model's credential, load-balancing state and available credentials.

          Query args:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              config_from: optional; ``"predefined-model"`` lists provider-level
                  credentials, anything else lists model-level credentials.
              credential_id: optional UUID of the credential to read.

          Returns:
              A JSON-encodable dict with ``credentials``, ``current_credential_id``,
              ``current_credential_name``, ``load_balancing`` and
              ``available_credentials``.
          """
          tenant_id = current_user.current_tenant_id

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="args",
          )
          parser.add_argument("config_from", type=str, required=False, nullable=True, location="args")
          parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args")
          args = parser.parse_args()

          # reqparse stores an absent optional argument as None, so
          # `args.get("config_from", "")` never yields ""; normalise explicitly.
          config_from = args.get("config_from") or ""

          model_provider_service = ModelProviderService()
          current_credential = model_provider_service.get_model_credential(
              tenant_id=tenant_id,
              provider=provider,
              model_type=args["model_type"],
              model=args["model"],
              credential_id=args.get("credential_id"),
          )

          model_load_balancing_service = ModelLoadBalancingService()
          is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
              tenant_id=tenant_id,
              provider=provider,
              model=args["model"],
              model_type=args["model_type"],
              config_from=config_from,
          )

          provider_manager = model_provider_service.provider_manager
          if config_from == "predefined-model":
              available_credentials = provider_manager.get_provider_available_credentials(
                  tenant_id=tenant_id, provider_name=provider
              )
          else:
              model_type = ModelType.value_of(args["model_type"]).to_origin_model_type()
              available_credentials = provider_manager.get_provider_model_available_credentials(
                  tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"]
              )

          return jsonable_encoder(
              {
                  "credentials": current_credential.get("credentials") if current_credential else {},
                  "current_credential_id": current_credential.get("current_credential_id")
                  if current_credential
                  else None,
                  "current_credential_name": current_credential.get("current_credential_name")
                  if current_credential
                  else None,
                  "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
                  "available_credentials": available_credentials,
              }
          )

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Create a credential for a model.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              name: optional credential name, validated with ``StrLen(30)``.
              credentials: required object holding the credential values.

          Returns:
              ``({"result": "success"}, 201)``.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: the service rejected the credentials
                  (``CredentialsValidateFailedError``).
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()

          try:
              model_provider_service.create_model_credential(
                  tenant_id=tenant_id,
                  provider=provider,
                  model=args["model"],
                  model_type=args["model_type"],
                  credentials=args["credentials"],
                  credential_name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              logger.exception(
                  "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
                  tenant_id,
                  args.get("model"),
                  args.get("model_type"),
              )
              # Chain the cause so the original validation error stays in the traceback.
              raise ValueError(str(ex)) from ex

          return {"result": "success"}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def put(self, provider: str):
          """Update an existing credential of a model.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              credential_id: required UUID of the credential to update.
              credentials: required object holding the new credential values.
              name: optional credential name, validated with ``StrLen(30)``.

          Returns:
              ``{"result": "success"}``.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: the service rejected the credentials
                  (``CredentialsValidateFailedError``).
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
          args = parser.parse_args()

          tenant_id = current_user.current_tenant_id
          model_provider_service = ModelProviderService()

          try:
              model_provider_service.update_model_credential(
                  tenant_id=tenant_id,
                  provider=provider,
                  model_type=args["model_type"],
                  model=args["model"],
                  credentials=args["credentials"],
                  credential_id=args["credential_id"],
                  credential_name=args["name"],
              )
          except CredentialsValidateFailedError as ex:
              # Logged the same way as the create path so both failures are traceable.
              logger.exception(
                  "Failed to update model credentials, tenant_id: %s, model: %s, model_type: %s",
                  tenant_id,
                  args.get("model"),
                  args.get("model_type"),
              )
              raise ValueError(str(ex)) from ex

          return {"result": "success"}

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Delete a credential of a model.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              credential_id: required UUID of the credential to remove.

          Returns:
              ``({"result": "success"}, 204)``.

          Raises:
              Forbidden: the current user is not an admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          args = parser.parse_args()

          model_provider_service = ModelProviderService()
          model_provider_service.remove_model_credential(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
              model_type=args["model_type"],
              model=args["model"],
              credential_id=args["credential_id"],
          )

          return {"result": "success"}, 204
  class ModelProviderModelCredentialSwitchApi(Resource):
      """Attach an existing credential to a model in the provider's model list."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Add the given credential to the model list entry for ``model``.

          JSON body:
              model: required model name.
              model_type: required; one of the ``ModelType`` values.
              credential_id: required UUID of the credential to use.

          Returns:
              ``{"result": "success"}``.

          Raises:
              Forbidden: the current user is not an admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[mt.value for mt in ModelType],
              location="json",
          )
          # Validated with uuid_value, like credential_id on every other endpoint in
          # this module, so a malformed id is rejected at parse time.
          parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json")
          args = parser.parse_args()

          service = ModelProviderService()
          service.add_model_credential_to_model_list(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
              model_type=args["model_type"],
              model=args["model"],
              credential_id=args["credential_id"],
          )

          return {"result": "success"}
  class ModelProviderModelEnableApi(Resource):
      """Enable a single model of a provider for the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Enable the model named in the JSON body (``model`` and ``model_type``).

          Returns:
              ``{"result": "success"}``.
          """
          workspace_id = current_user.current_tenant_id

          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          request_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="json",
          )
          payload = request_parser.parse_args()

          ModelProviderService().enable_model(
              tenant_id=workspace_id,
              provider=provider,
              model=payload["model"],
              model_type=payload["model_type"],
          )
          return {"result": "success"}
  class ModelProviderModelDisableApi(Resource):
      """Disable a single model of a provider for the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def patch(self, provider: str):
          """Disable the model named in the JSON body (``model`` and ``model_type``).

          Returns:
              ``{"result": "success"}``.
          """
          workspace_id = current_user.current_tenant_id

          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          request_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="json",
          )
          payload = request_parser.parse_args()

          ModelProviderService().disable_model(
              tenant_id=workspace_id,
              provider=provider,
              model=payload["model"],
              model_type=payload["model_type"],
          )
          return {"result": "success"}
  class ModelProviderModelValidateApi(Resource):
      """Check a set of model credentials without saving them."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Validate ``credentials`` for the given ``model`` / ``model_type``.

          Returns:
              ``{"result": "success"}`` when validation passes, otherwise
              ``{"result": "error", "error": <message>}``. A
              ``CredentialsValidateFailedError`` is reported in the payload
              rather than raised.
          """
          workspace_id = current_user.current_tenant_id

          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="json")
          request_parser.add_argument(
              "model_type",
              type=str,
              required=True,
              nullable=False,
              choices=[member.value for member in ModelType],
              location="json",
          )
          request_parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
          payload = request_parser.parse_args()

          provider_service = ModelProviderService()
          try:
              provider_service.validate_model_credentials(
                  tenant_id=workspace_id,
                  provider=provider,
                  model=payload["model"],
                  model_type=payload["model_type"],
                  credentials=payload["credentials"],
              )
          except CredentialsValidateFailedError as failure:
              # Validation failures are an expected outcome of this endpoint.
              return {"result": "error", "error": str(failure)}

          return {"result": "success"}
  class ModelProviderModelParameterRuleApi(Resource):
      """Expose the parameter rules of a provider's model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return ``{"data": <parameter rules>}`` for the ``model`` query argument."""
          request_parser = reqparse.RequestParser()
          request_parser.add_argument("model", type=str, required=True, nullable=False, location="args")
          query = request_parser.parse_args()

          workspace_id = current_user.current_tenant_id
          provider_service = ModelProviderService()
          rules = provider_service.get_model_parameter_rules(
              tenant_id=workspace_id,
              provider=provider,
              model=query["model"],
          )
          return jsonable_encoder({"data": rules})
  class ModelProviderAvailableModelApi(Resource):
      """List the models of one model type for the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, model_type):
          """Return ``{"data": <models>}`` for the ``model_type`` path segment."""
          workspace_id = current_user.current_tenant_id
          provider_service = ModelProviderService()
          payload = {
              "data": provider_service.get_models_by_model_type(tenant_id=workspace_id, model_type=model_type),
          }
          return jsonable_encoder(payload)
  # --- Routes scoped to one provider's models --------------------------------
  api.add_resource(
      ModelProviderModelApi,
      "/workspaces/current/model-providers/<path:provider>/models",
  )
  api.add_resource(
      ModelProviderModelEnableApi,
      "/workspaces/current/model-providers/<path:provider>/models/enable",
      endpoint="model-provider-model-enable",
  )
  api.add_resource(
      ModelProviderModelDisableApi,
      "/workspaces/current/model-providers/<path:provider>/models/disable",
      endpoint="model-provider-model-disable",
  )

  # --- Model credential routes -----------------------------------------------
  api.add_resource(
      ModelProviderModelCredentialApi,
      "/workspaces/current/model-providers/<path:provider>/models/credentials",
  )
  api.add_resource(
      ModelProviderModelCredentialSwitchApi,
      "/workspaces/current/model-providers/<path:provider>/models/credentials/switch",
  )
  api.add_resource(
      ModelProviderModelValidateApi,
      "/workspaces/current/model-providers/<path:provider>/models/credentials/validate",
  )

  # --- Parameter rules --------------------------------------------------------
  api.add_resource(
      ModelProviderModelParameterRuleApi,
      "/workspaces/current/model-providers/<path:provider>/models/parameter-rules",
  )

  # --- Workspace-wide model routes -------------------------------------------
  api.add_resource(
      ModelProviderAvailableModelApi,
      "/workspaces/current/models/model-types/<string:model_type>",
  )
  api.add_resource(
      DefaultModelApi,
      "/workspaces/current/default-model",
  )