You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. import logging
  2. from flask_login import current_user
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.wraps import account_initialization_required, setup_required
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. from libs.login import login_required
  11. from services.model_load_balancing_service import ModelLoadBalancingService
  12. from services.model_provider_service import ModelProviderService
  13. class DefaultModelApi(Resource):
  14. @setup_required
  15. @login_required
  16. @account_initialization_required
  17. def get(self):
  18. parser = reqparse.RequestParser()
  19. parser.add_argument(
  20. "model_type",
  21. type=str,
  22. required=True,
  23. nullable=False,
  24. choices=[mt.value for mt in ModelType],
  25. location="args",
  26. )
  27. args = parser.parse_args()
  28. tenant_id = current_user.current_tenant_id
  29. model_provider_service = ModelProviderService()
  30. default_model_entity = model_provider_service.get_default_model_of_model_type(
  31. tenant_id=tenant_id, model_type=args["model_type"]
  32. )
  33. return jsonable_encoder({"data": default_model_entity})
  34. @setup_required
  35. @login_required
  36. @account_initialization_required
  37. def post(self):
  38. if not current_user.is_admin_or_owner:
  39. raise Forbidden()
  40. parser = reqparse.RequestParser()
  41. parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
  42. args = parser.parse_args()
  43. tenant_id = current_user.current_tenant_id
  44. model_provider_service = ModelProviderService()
  45. model_settings = args["model_settings"]
  46. for model_setting in model_settings:
  47. if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]:
  48. raise ValueError("invalid model type")
  49. if "provider" not in model_setting:
  50. continue
  51. if "model" not in model_setting:
  52. raise ValueError("invalid model")
  53. try:
  54. model_provider_service.update_default_model_of_model_type(
  55. tenant_id=tenant_id,
  56. model_type=model_setting["model_type"],
  57. provider=model_setting["provider"],
  58. model=model_setting["model"],
  59. )
  60. except Exception as ex:
  61. logging.exception(
  62. "Failed to update default model, model type: %s, model: %s",
  63. model_setting["model_type"],
  64. model_setting.get("model"),
  65. )
  66. raise ex
  67. return {"result": "success"}
  68. class ModelProviderModelApi(Resource):
  69. @setup_required
  70. @login_required
  71. @account_initialization_required
  72. def get(self, provider):
  73. tenant_id = current_user.current_tenant_id
  74. model_provider_service = ModelProviderService()
  75. models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)
  76. return jsonable_encoder({"data": models})
  77. @setup_required
  78. @login_required
  79. @account_initialization_required
  80. def post(self, provider: str):
  81. if not current_user.is_admin_or_owner:
  82. raise Forbidden()
  83. tenant_id = current_user.current_tenant_id
  84. parser = reqparse.RequestParser()
  85. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  86. parser.add_argument(
  87. "model_type",
  88. type=str,
  89. required=True,
  90. nullable=False,
  91. choices=[mt.value for mt in ModelType],
  92. location="json",
  93. )
  94. parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
  95. parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
  96. parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
  97. args = parser.parse_args()
  98. model_load_balancing_service = ModelLoadBalancingService()
  99. if (
  100. "load_balancing" in args
  101. and args["load_balancing"]
  102. and "enabled" in args["load_balancing"]
  103. and args["load_balancing"]["enabled"]
  104. ):
  105. if "configs" not in args["load_balancing"]:
  106. raise ValueError("invalid load balancing configs")
  107. # save load balancing configs
  108. model_load_balancing_service.update_load_balancing_configs(
  109. tenant_id=tenant_id,
  110. provider=provider,
  111. model=args["model"],
  112. model_type=args["model_type"],
  113. configs=args["load_balancing"]["configs"],
  114. )
  115. # enable load balancing
  116. model_load_balancing_service.enable_model_load_balancing(
  117. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  118. )
  119. else:
  120. # disable load balancing
  121. model_load_balancing_service.disable_model_load_balancing(
  122. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  123. )
  124. if args.get("config_from", "") != "predefined-model":
  125. model_provider_service = ModelProviderService()
  126. try:
  127. model_provider_service.save_model_credentials(
  128. tenant_id=tenant_id,
  129. provider=provider,
  130. model=args["model"],
  131. model_type=args["model_type"],
  132. credentials=args["credentials"],
  133. )
  134. except CredentialsValidateFailedError as ex:
  135. logging.exception(
  136. "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
  137. tenant_id,
  138. args.get("model"),
  139. args.get("model_type"),
  140. )
  141. raise ValueError(str(ex))
  142. return {"result": "success"}, 200
  143. @setup_required
  144. @login_required
  145. @account_initialization_required
  146. def delete(self, provider: str):
  147. if not current_user.is_admin_or_owner:
  148. raise Forbidden()
  149. tenant_id = current_user.current_tenant_id
  150. parser = reqparse.RequestParser()
  151. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  152. parser.add_argument(
  153. "model_type",
  154. type=str,
  155. required=True,
  156. nullable=False,
  157. choices=[mt.value for mt in ModelType],
  158. location="json",
  159. )
  160. args = parser.parse_args()
  161. model_provider_service = ModelProviderService()
  162. model_provider_service.remove_model_credentials(
  163. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  164. )
  165. return {"result": "success"}, 204
  166. class ModelProviderModelCredentialApi(Resource):
  167. @setup_required
  168. @login_required
  169. @account_initialization_required
  170. def get(self, provider: str):
  171. tenant_id = current_user.current_tenant_id
  172. parser = reqparse.RequestParser()
  173. parser.add_argument("model", type=str, required=True, nullable=False, location="args")
  174. parser.add_argument(
  175. "model_type",
  176. type=str,
  177. required=True,
  178. nullable=False,
  179. choices=[mt.value for mt in ModelType],
  180. location="args",
  181. )
  182. args = parser.parse_args()
  183. model_provider_service = ModelProviderService()
  184. credentials = model_provider_service.get_model_credentials(
  185. tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
  186. )
  187. model_load_balancing_service = ModelLoadBalancingService()
  188. is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
  189. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  190. )
  191. return {
  192. "credentials": credentials,
  193. "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
  194. }
  195. class ModelProviderModelEnableApi(Resource):
  196. @setup_required
  197. @login_required
  198. @account_initialization_required
  199. def patch(self, provider: str):
  200. tenant_id = current_user.current_tenant_id
  201. parser = reqparse.RequestParser()
  202. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  203. parser.add_argument(
  204. "model_type",
  205. type=str,
  206. required=True,
  207. nullable=False,
  208. choices=[mt.value for mt in ModelType],
  209. location="json",
  210. )
  211. args = parser.parse_args()
  212. model_provider_service = ModelProviderService()
  213. model_provider_service.enable_model(
  214. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  215. )
  216. return {"result": "success"}
  217. class ModelProviderModelDisableApi(Resource):
  218. @setup_required
  219. @login_required
  220. @account_initialization_required
  221. def patch(self, provider: str):
  222. tenant_id = current_user.current_tenant_id
  223. parser = reqparse.RequestParser()
  224. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  225. parser.add_argument(
  226. "model_type",
  227. type=str,
  228. required=True,
  229. nullable=False,
  230. choices=[mt.value for mt in ModelType],
  231. location="json",
  232. )
  233. args = parser.parse_args()
  234. model_provider_service = ModelProviderService()
  235. model_provider_service.disable_model(
  236. tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
  237. )
  238. return {"result": "success"}
  239. class ModelProviderModelValidateApi(Resource):
  240. @setup_required
  241. @login_required
  242. @account_initialization_required
  243. def post(self, provider: str):
  244. tenant_id = current_user.current_tenant_id
  245. parser = reqparse.RequestParser()
  246. parser.add_argument("model", type=str, required=True, nullable=False, location="json")
  247. parser.add_argument(
  248. "model_type",
  249. type=str,
  250. required=True,
  251. nullable=False,
  252. choices=[mt.value for mt in ModelType],
  253. location="json",
  254. )
  255. parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
  256. args = parser.parse_args()
  257. model_provider_service = ModelProviderService()
  258. result = True
  259. error = ""
  260. try:
  261. model_provider_service.model_credentials_validate(
  262. tenant_id=tenant_id,
  263. provider=provider,
  264. model=args["model"],
  265. model_type=args["model_type"],
  266. credentials=args["credentials"],
  267. )
  268. except CredentialsValidateFailedError as ex:
  269. result = False
  270. error = str(ex)
  271. response = {"result": "success" if result else "error"}
  272. if not result:
  273. response["error"] = error or ""
  274. return response
  275. class ModelProviderModelParameterRuleApi(Resource):
  276. @setup_required
  277. @login_required
  278. @account_initialization_required
  279. def get(self, provider: str):
  280. parser = reqparse.RequestParser()
  281. parser.add_argument("model", type=str, required=True, nullable=False, location="args")
  282. args = parser.parse_args()
  283. tenant_id = current_user.current_tenant_id
  284. model_provider_service = ModelProviderService()
  285. parameter_rules = model_provider_service.get_model_parameter_rules(
  286. tenant_id=tenant_id, provider=provider, model=args["model"]
  287. )
  288. return jsonable_encoder({"data": parameter_rules})
  289. class ModelProviderAvailableModelApi(Resource):
  290. @setup_required
  291. @login_required
  292. @account_initialization_required
  293. def get(self, model_type):
  294. tenant_id = current_user.current_tenant_id
  295. model_provider_service = ModelProviderService()
  296. models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
  297. return jsonable_encoder({"data": models})
  298. api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<path:provider>/models")
  299. api.add_resource(
  300. ModelProviderModelEnableApi,
  301. "/workspaces/current/model-providers/<path:provider>/models/enable",
  302. endpoint="model-provider-model-enable",
  303. )
  304. api.add_resource(
  305. ModelProviderModelDisableApi,
  306. "/workspaces/current/model-providers/<path:provider>/models/disable",
  307. endpoint="model-provider-model-disable",
  308. )
  309. api.add_resource(
  310. ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials"
  311. )
  312. api.add_resource(
  313. ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate"
  314. )
  315. api.add_resource(
  316. ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<path:provider>/models/parameter-rules"
  317. )
  318. api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
  319. api.add_resource(DefaultModelApi, "/workspaces/current/default-model")