Du kan inte välja fler än 25 ämnen. Ämnen måste starta med en bokstav eller siffra, kan innehålla bindestreck ('-') och vara max 35 tecken långa.

models.py 13KB


  1. import logging
  2. from flask_login import current_user
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.setup import setup_required
  7. from controllers.console.wraps import account_initialization_required
  8. from core.model_runtime.entities.model_entities import ModelType
  9. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  10. from core.model_runtime.utils.encoders import jsonable_encoder
  11. from libs.login import login_required
  12. from services.model_load_balancing_service import ModelLoadBalancingService
  13. from services.model_provider_service import ModelProviderService
  14. class DefaultModelApi(Resource):
  15. @setup_required
  16. @login_required
  17. @account_initialization_required
  18. def get(self):
  19. parser = reqparse.RequestParser()
  20. parser.add_argument('model_type', type=str, required=True, nullable=False,
  21. choices=[mt.value for mt in ModelType], location='args')
  22. args = parser.parse_args()
  23. tenant_id = current_user.current_tenant_id
  24. model_provider_service = ModelProviderService()
  25. default_model_entity = model_provider_service.get_default_model_of_model_type(
  26. tenant_id=tenant_id,
  27. model_type=args['model_type']
  28. )
  29. return jsonable_encoder({
  30. "data": default_model_entity
  31. })
  32. @setup_required
  33. @login_required
  34. @account_initialization_required
  35. def post(self):
  36. if not current_user.is_admin_or_owner:
  37. raise Forbidden()
  38. parser = reqparse.RequestParser()
  39. parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
  40. args = parser.parse_args()
  41. tenant_id = current_user.current_tenant_id
  42. model_provider_service = ModelProviderService()
  43. model_settings = args['model_settings']
  44. for model_setting in model_settings:
  45. if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
  46. raise ValueError('invalid model type')
  47. if 'provider' not in model_setting:
  48. continue
  49. if 'model' not in model_setting:
  50. raise ValueError('invalid model')
  51. try:
  52. model_provider_service.update_default_model_of_model_type(
  53. tenant_id=tenant_id,
  54. model_type=model_setting['model_type'],
  55. provider=model_setting['provider'],
  56. model=model_setting['model']
  57. )
  58. except Exception:
  59. logging.warning(f"{model_setting['model_type']} save error")
  60. return {'result': 'success'}
  61. class ModelProviderModelApi(Resource):
  62. @setup_required
  63. @login_required
  64. @account_initialization_required
  65. def get(self, provider):
  66. tenant_id = current_user.current_tenant_id
  67. model_provider_service = ModelProviderService()
  68. models = model_provider_service.get_models_by_provider(
  69. tenant_id=tenant_id,
  70. provider=provider
  71. )
  72. return jsonable_encoder({
  73. "data": models
  74. })
  75. @setup_required
  76. @login_required
  77. @account_initialization_required
  78. def post(self, provider: str):
  79. if not current_user.is_admin_or_owner:
  80. raise Forbidden()
  81. tenant_id = current_user.current_tenant_id
  82. parser = reqparse.RequestParser()
  83. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  84. parser.add_argument('model_type', type=str, required=True, nullable=False,
  85. choices=[mt.value for mt in ModelType], location='json')
  86. parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json')
  87. parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json')
  88. parser.add_argument('config_from', type=str, required=False, nullable=True, location='json')
  89. args = parser.parse_args()
  90. model_load_balancing_service = ModelLoadBalancingService()
  91. if ('load_balancing' in args and args['load_balancing'] and
  92. 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']):
  93. if 'configs' not in args['load_balancing']:
  94. raise ValueError('invalid load balancing configs')
  95. # save load balancing configs
  96. model_load_balancing_service.update_load_balancing_configs(
  97. tenant_id=tenant_id,
  98. provider=provider,
  99. model=args['model'],
  100. model_type=args['model_type'],
  101. configs=args['load_balancing']['configs']
  102. )
  103. # enable load balancing
  104. model_load_balancing_service.enable_model_load_balancing(
  105. tenant_id=tenant_id,
  106. provider=provider,
  107. model=args['model'],
  108. model_type=args['model_type']
  109. )
  110. else:
  111. # disable load balancing
  112. model_load_balancing_service.disable_model_load_balancing(
  113. tenant_id=tenant_id,
  114. provider=provider,
  115. model=args['model'],
  116. model_type=args['model_type']
  117. )
  118. if args.get('config_from', '') != 'predefined-model':
  119. model_provider_service = ModelProviderService()
  120. try:
  121. model_provider_service.save_model_credentials(
  122. tenant_id=tenant_id,
  123. provider=provider,
  124. model=args['model'],
  125. model_type=args['model_type'],
  126. credentials=args['credentials']
  127. )
  128. except CredentialsValidateFailedError as ex:
  129. logging.exception(f"save model credentials error: {ex}")
  130. raise ValueError(str(ex))
  131. return {'result': 'success'}, 200
  132. @setup_required
  133. @login_required
  134. @account_initialization_required
  135. def delete(self, provider: str):
  136. if not current_user.is_admin_or_owner:
  137. raise Forbidden()
  138. tenant_id = current_user.current_tenant_id
  139. parser = reqparse.RequestParser()
  140. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  141. parser.add_argument('model_type', type=str, required=True, nullable=False,
  142. choices=[mt.value for mt in ModelType], location='json')
  143. args = parser.parse_args()
  144. model_provider_service = ModelProviderService()
  145. model_provider_service.remove_model_credentials(
  146. tenant_id=tenant_id,
  147. provider=provider,
  148. model=args['model'],
  149. model_type=args['model_type']
  150. )
  151. return {'result': 'success'}, 204
  152. class ModelProviderModelCredentialApi(Resource):
  153. @setup_required
  154. @login_required
  155. @account_initialization_required
  156. def get(self, provider: str):
  157. tenant_id = current_user.current_tenant_id
  158. parser = reqparse.RequestParser()
  159. parser.add_argument('model', type=str, required=True, nullable=False, location='args')
  160. parser.add_argument('model_type', type=str, required=True, nullable=False,
  161. choices=[mt.value for mt in ModelType], location='args')
  162. args = parser.parse_args()
  163. model_provider_service = ModelProviderService()
  164. credentials = model_provider_service.get_model_credentials(
  165. tenant_id=tenant_id,
  166. provider=provider,
  167. model_type=args['model_type'],
  168. model=args['model']
  169. )
  170. model_load_balancing_service = ModelLoadBalancingService()
  171. is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
  172. tenant_id=tenant_id,
  173. provider=provider,
  174. model=args['model'],
  175. model_type=args['model_type']
  176. )
  177. return {
  178. "credentials": credentials,
  179. "load_balancing": {
  180. "enabled": is_load_balancing_enabled,
  181. "configs": load_balancing_configs
  182. }
  183. }
  184. class ModelProviderModelEnableApi(Resource):
  185. @setup_required
  186. @login_required
  187. @account_initialization_required
  188. def patch(self, provider: str):
  189. tenant_id = current_user.current_tenant_id
  190. parser = reqparse.RequestParser()
  191. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  192. parser.add_argument('model_type', type=str, required=True, nullable=False,
  193. choices=[mt.value for mt in ModelType], location='json')
  194. args = parser.parse_args()
  195. model_provider_service = ModelProviderService()
  196. model_provider_service.enable_model(
  197. tenant_id=tenant_id,
  198. provider=provider,
  199. model=args['model'],
  200. model_type=args['model_type']
  201. )
  202. return {'result': 'success'}
  203. class ModelProviderModelDisableApi(Resource):
  204. @setup_required
  205. @login_required
  206. @account_initialization_required
  207. def patch(self, provider: str):
  208. tenant_id = current_user.current_tenant_id
  209. parser = reqparse.RequestParser()
  210. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  211. parser.add_argument('model_type', type=str, required=True, nullable=False,
  212. choices=[mt.value for mt in ModelType], location='json')
  213. args = parser.parse_args()
  214. model_provider_service = ModelProviderService()
  215. model_provider_service.disable_model(
  216. tenant_id=tenant_id,
  217. provider=provider,
  218. model=args['model'],
  219. model_type=args['model_type']
  220. )
  221. return {'result': 'success'}
  222. class ModelProviderModelValidateApi(Resource):
  223. @setup_required
  224. @login_required
  225. @account_initialization_required
  226. def post(self, provider: str):
  227. tenant_id = current_user.current_tenant_id
  228. parser = reqparse.RequestParser()
  229. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  230. parser.add_argument('model_type', type=str, required=True, nullable=False,
  231. choices=[mt.value for mt in ModelType], location='json')
  232. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  233. args = parser.parse_args()
  234. model_provider_service = ModelProviderService()
  235. result = True
  236. error = None
  237. try:
  238. model_provider_service.model_credentials_validate(
  239. tenant_id=tenant_id,
  240. provider=provider,
  241. model=args['model'],
  242. model_type=args['model_type'],
  243. credentials=args['credentials']
  244. )
  245. except CredentialsValidateFailedError as ex:
  246. result = False
  247. error = str(ex)
  248. response = {'result': 'success' if result else 'error'}
  249. if not result:
  250. response['error'] = error
  251. return response
  252. class ModelProviderModelParameterRuleApi(Resource):
  253. @setup_required
  254. @login_required
  255. @account_initialization_required
  256. def get(self, provider: str):
  257. parser = reqparse.RequestParser()
  258. parser.add_argument('model', type=str, required=True, nullable=False, location='args')
  259. args = parser.parse_args()
  260. tenant_id = current_user.current_tenant_id
  261. model_provider_service = ModelProviderService()
  262. parameter_rules = model_provider_service.get_model_parameter_rules(
  263. tenant_id=tenant_id,
  264. provider=provider,
  265. model=args['model']
  266. )
  267. return jsonable_encoder({
  268. "data": parameter_rules
  269. })
  270. class ModelProviderAvailableModelApi(Resource):
  271. @setup_required
  272. @login_required
  273. @account_initialization_required
  274. def get(self, model_type):
  275. tenant_id = current_user.current_tenant_id
  276. model_provider_service = ModelProviderService()
  277. models = model_provider_service.get_models_by_model_type(
  278. tenant_id=tenant_id,
  279. model_type=model_type
  280. )
  281. return jsonable_encoder({
  282. "data": models
  283. })
  284. api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
  285. api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers/<string:provider>/models/enable',
  286. endpoint='model-provider-model-enable')
  287. api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers/<string:provider>/models/disable',
  288. endpoint='model-provider-model-disable')
  289. api.add_resource(ModelProviderModelCredentialApi,
  290. '/workspaces/current/model-providers/<string:provider>/models/credentials')
  291. api.add_resource(ModelProviderModelValidateApi,
  292. '/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
  293. api.add_resource(ModelProviderModelParameterRuleApi,
  294. '/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
  295. api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
  296. api.add_resource(DefaultModelApi, '/workspaces/current/default-model')