| @@ -1,3 +1 @@ | |||
| @@ -2,7 +2,7 @@ from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint('console', __name__, url_prefix='/console/api') | |||
| bp = Blueprint("console", __name__, url_prefix="/console/api") | |||
| api = ExternalApi(bp) | |||
| # Import other controllers | |||
| @@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp | |||
| def admin_required(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not os.getenv('ADMIN_API_KEY'): | |||
| raise Unauthorized('API key is invalid.') | |||
| if not os.getenv("ADMIN_API_KEY"): | |||
| raise Unauthorized("API key is invalid.") | |||
| auth_header = request.headers.get('Authorization') | |||
| auth_header = request.headers.get("Authorization") | |||
| if auth_header is None: | |||
| raise Unauthorized('Authorization header is missing.') | |||
| raise Unauthorized("Authorization header is missing.") | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| if " " not in auth_header: | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.") | |||
| if os.getenv('ADMIN_API_KEY') != auth_token: | |||
| raise Unauthorized('API key is invalid.') | |||
| if os.getenv("ADMIN_API_KEY") != auth_token: | |||
| raise Unauthorized("API key is invalid.") | |||
| return view(*args, **kwargs) | |||
| @@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource): | |||
| @admin_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('desc', type=str, location='json') | |||
| parser.add_argument('copyright', type=str, location='json') | |||
| parser.add_argument('privacy_policy', type=str, location='json') | |||
| parser.add_argument('custom_disclaimer', type=str, location='json') | |||
| parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') | |||
| parser.add_argument('category', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('position', type=int, required=True, nullable=False, location='json') | |||
| parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("desc", type=str, location="json") | |||
| parser.add_argument("copyright", type=str, location="json") | |||
| parser.add_argument("privacy_policy", type=str, location="json") | |||
| parser.add_argument("custom_disclaimer", type=str, location="json") | |||
| parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") | |||
| parser.add_argument("category", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("position", type=int, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| app = App.query.filter(App.id == args['app_id']).first() | |||
| app = App.query.filter(App.id == args["app_id"]).first() | |||
| if not app: | |||
| raise NotFound(f'App \'{args["app_id"]}\' is not found') | |||
| site = app.site | |||
| if not site: | |||
| desc = args['desc'] if args['desc'] else '' | |||
| copy_right = args['copyright'] if args['copyright'] else '' | |||
| privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' | |||
| custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' | |||
| desc = args["desc"] if args["desc"] else "" | |||
| copy_right = args["copyright"] if args["copyright"] else "" | |||
| privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" | |||
| custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" | |||
| else: | |||
| desc = site.description if site.description else \ | |||
| args['desc'] if args['desc'] else '' | |||
| copy_right = site.copyright if site.copyright else \ | |||
| args['copyright'] if args['copyright'] else '' | |||
| privacy_policy = site.privacy_policy if site.privacy_policy else \ | |||
| args['privacy_policy'] if args['privacy_policy'] else '' | |||
| custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ | |||
| args['custom_disclaimer'] if args['custom_disclaimer'] else '' | |||
| desc = site.description if site.description else args["desc"] if args["desc"] else "" | |||
| copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" | |||
| privacy_policy = ( | |||
| site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" | |||
| ) | |||
| custom_disclaimer = ( | |||
| site.custom_disclaimer | |||
| if site.custom_disclaimer | |||
| else args["custom_disclaimer"] | |||
| if args["custom_disclaimer"] | |||
| else "" | |||
| ) | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() | |||
| if not recommended_app: | |||
| recommended_app = RecommendedApp( | |||
| @@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource): | |||
| copyright=copy_right, | |||
| privacy_policy=privacy_policy, | |||
| custom_disclaimer=custom_disclaimer, | |||
| language=args['language'], | |||
| category=args['category'], | |||
| position=args['position'] | |||
| language=args["language"], | |||
| category=args["category"], | |||
| position=args["position"], | |||
| ) | |||
| db.session.add(recommended_app) | |||
| @@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource): | |||
| app.is_public = True | |||
| db.session.commit() | |||
| return {'result': 'success'}, 201 | |||
| return {"result": "success"}, 201 | |||
| else: | |||
| recommended_app.description = desc | |||
| recommended_app.copyright = copy_right | |||
| recommended_app.privacy_policy = privacy_policy | |||
| recommended_app.custom_disclaimer = custom_disclaimer | |||
| recommended_app.language = args['language'] | |||
| recommended_app.category = args['category'] | |||
| recommended_app.position = args['position'] | |||
| recommended_app.language = args["language"] | |||
| recommended_app.category = args["category"] | |||
| recommended_app.position = args["position"] | |||
| app.is_public = True | |||
| db.session.commit() | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class InsertExploreAppApi(Resource): | |||
| @@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource): | |||
| def delete(self, app_id): | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() | |||
| if not recommended_app: | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| app = App.query.filter(App.id == recommended_app.app_id).first() | |||
| if app: | |||
| app.is_public = False | |||
| installed_apps = InstalledApp.query.filter( | |||
| InstalledApp.app_id == recommended_app.app_id, | |||
| InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id | |||
| InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id | |||
| ).all() | |||
| for installed_app in installed_apps: | |||
| @@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource): | |||
| db.session.delete(recommended_app) | |||
| db.session.commit() | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') | |||
| api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/<uuid:app_id>') | |||
| api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") | |||
| api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/<uuid:app_id>") | |||
| @@ -14,26 +14,21 @@ from .setup import setup_required | |||
| from .wraps import account_initialization_required | |||
| api_key_fields = { | |||
| 'id': fields.String, | |||
| 'type': fields.String, | |||
| 'token': fields.String, | |||
| 'last_used_at': TimestampField, | |||
| 'created_at': TimestampField | |||
| "id": fields.String, | |||
| "type": fields.String, | |||
| "token": fields.String, | |||
| "last_used_at": TimestampField, | |||
| "created_at": TimestampField, | |||
| } | |||
| api_key_list = { | |||
| 'data': fields.List(fields.Nested(api_key_fields), attribute="items") | |||
| } | |||
| api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} | |||
| def _get_resource(resource_id, tenant_id, resource_model): | |||
| resource = resource_model.query.filter_by( | |||
| id=resource_id, tenant_id=tenant_id | |||
| ).first() | |||
| resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() | |||
| if resource is None: | |||
| flask_restful.abort( | |||
| 404, message=f"{resource_model.__name__} not found.") | |||
| flask_restful.abort(404, message=f"{resource_model.__name__} not found.") | |||
| return resource | |||
| @@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource): | |||
| @marshal_with(api_key_list) | |||
| def get(self, resource_id): | |||
| resource_id = str(resource_id) | |||
| _get_resource(resource_id, current_user.current_tenant_id, | |||
| self.resource_model) | |||
| keys = db.session.query(ApiToken). \ | |||
| filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ | |||
| all() | |||
| _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | |||
| keys = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .all() | |||
| ) | |||
| return {"items": keys} | |||
| @marshal_with(api_key_fields) | |||
| def post(self, resource_id): | |||
| resource_id = str(resource_id) | |||
| _get_resource(resource_id, current_user.current_tenant_id, | |||
| self.resource_model) | |||
| _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| current_key_count = db.session.query(ApiToken). \ | |||
| filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ | |||
| count() | |||
| current_key_count = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) | |||
| .count() | |||
| ) | |||
| if current_key_count >= self.max_keys: | |||
| flask_restful.abort( | |||
| 400, | |||
| message=f"Cannot create more than {self.max_keys} API keys for this resource type.", | |||
| code='max_keys_exceeded' | |||
| code="max_keys_exceeded", | |||
| ) | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| @@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource): | |||
| def delete(self, resource_id, api_key_id): | |||
| resource_id = str(resource_id) | |||
| api_key_id = str(api_key_id) | |||
| _get_resource(resource_id, current_user.current_tenant_id, | |||
| self.resource_model) | |||
| _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| key = db.session.query(ApiToken). \ | |||
| filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \ | |||
| first() | |||
| key = ( | |||
| db.session.query(ApiToken) | |||
| .filter( | |||
| getattr(ApiToken, self.resource_id_field) == resource_id, | |||
| ApiToken.type == self.resource_type, | |||
| ApiToken.id == api_key_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if key is None: | |||
| flask_restful.abort(404, message='API key not found') | |||
| flask_restful.abort(404, message="API key not found") | |||
| db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() | |||
| db.session.commit() | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class AppApiKeyListResource(BaseApiKeyListResource): | |||
| def after_request(self, resp): | |||
| resp.headers['Access-Control-Allow-Origin'] = '*' | |||
| resp.headers['Access-Control-Allow-Credentials'] = 'true' | |||
| resp.headers["Access-Control-Allow-Origin"] = "*" | |||
| resp.headers["Access-Control-Allow-Credentials"] = "true" | |||
| return resp | |||
| resource_type = 'app' | |||
| resource_type = "app" | |||
| resource_model = App | |||
| resource_id_field = 'app_id' | |||
| token_prefix = 'app-' | |||
| resource_id_field = "app_id" | |||
| token_prefix = "app-" | |||
| class AppApiKeyResource(BaseApiKeyResource): | |||
| def after_request(self, resp): | |||
| resp.headers['Access-Control-Allow-Origin'] = '*' | |||
| resp.headers['Access-Control-Allow-Credentials'] = 'true' | |||
| resp.headers["Access-Control-Allow-Origin"] = "*" | |||
| resp.headers["Access-Control-Allow-Credentials"] = "true" | |||
| return resp | |||
| resource_type = 'app' | |||
| resource_type = "app" | |||
| resource_model = App | |||
| resource_id_field = 'app_id' | |||
| resource_id_field = "app_id" | |||
| class DatasetApiKeyListResource(BaseApiKeyListResource): | |||
| def after_request(self, resp): | |||
| resp.headers['Access-Control-Allow-Origin'] = '*' | |||
| resp.headers['Access-Control-Allow-Credentials'] = 'true' | |||
| resp.headers["Access-Control-Allow-Origin"] = "*" | |||
| resp.headers["Access-Control-Allow-Credentials"] = "true" | |||
| return resp | |||
| resource_type = 'dataset' | |||
| resource_type = "dataset" | |||
| resource_model = Dataset | |||
| resource_id_field = 'dataset_id' | |||
| token_prefix = 'ds-' | |||
| resource_id_field = "dataset_id" | |||
| token_prefix = "ds-" | |||
| class DatasetApiKeyResource(BaseApiKeyResource): | |||
| def after_request(self, resp): | |||
| resp.headers['Access-Control-Allow-Origin'] = '*' | |||
| resp.headers['Access-Control-Allow-Credentials'] = 'true' | |||
| resp.headers["Access-Control-Allow-Origin"] = "*" | |||
| resp.headers["Access-Control-Allow-Credentials"] = "true" | |||
| return resp | |||
| resource_type = 'dataset' | |||
| resource_type = "dataset" | |||
| resource_model = Dataset | |||
| resource_id_field = 'dataset_id' | |||
| resource_id_field = "dataset_id" | |||
| api.add_resource(AppApiKeyListResource, '/apps/<uuid:resource_id>/api-keys') | |||
| api.add_resource(AppApiKeyResource, | |||
| '/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>') | |||
| api.add_resource(DatasetApiKeyListResource, | |||
| '/datasets/<uuid:resource_id>/api-keys') | |||
| api.add_resource(DatasetApiKeyResource, | |||
| '/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>') | |||
| api.add_resource(AppApiKeyListResource, "/apps/<uuid:resource_id>/api-keys") | |||
| api.add_resource(AppApiKeyResource, "/apps/<uuid:resource_id>/api-keys/<uuid:api_key_id>") | |||
| api.add_resource(DatasetApiKeyListResource, "/datasets/<uuid:resource_id>/api-keys") | |||
| api.add_resource(DatasetApiKeyResource, "/datasets/<uuid:resource_id>/api-keys/<uuid:api_key_id>") | |||
| @@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ | |||
| class AdvancedPromptTemplateList(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('app_mode', type=str, required=True, location='args') | |||
| parser.add_argument('model_mode', type=str, required=True, location='args') | |||
| parser.add_argument('has_context', type=str, required=False, default='true', location='args') | |||
| parser.add_argument('model_name', type=str, required=True, location='args') | |||
| parser.add_argument("app_mode", type=str, required=True, location="args") | |||
| parser.add_argument("model_mode", type=str, required=True, location="args") | |||
| parser.add_argument("has_context", type=str, required=False, default="true", location="args") | |||
| parser.add_argument("model_name", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| return AdvancedPromptTemplateService.get_prompt(args) | |||
| api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') | |||
| api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") | |||
| @@ -18,15 +18,12 @@ class AgentLogApi(Resource): | |||
| def get(self, app_model): | |||
| """Get agent logs""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', type=uuid_value, required=True, location='args') | |||
| parser.add_argument('conversation_id', type=uuid_value, required=True, location='args') | |||
| parser.add_argument("message_id", type=uuid_value, required=True, location="args") | |||
| parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") | |||
| args = parser.parse_args() | |||
| return AgentService.get_agent_logs( | |||
| app_model, | |||
| args['conversation_id'], | |||
| args['message_id'] | |||
| ) | |||
| api.add_resource(AgentLogApi, '/apps/<uuid:app_id>/agent/logs') | |||
| return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) | |||
| api.add_resource(AgentLogApi, "/apps/<uuid:app_id>/agent/logs") | |||
| @@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| def post(self, app_id, action): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('score_threshold', required=True, type=float, location='json') | |||
| parser.add_argument('embedding_provider_name', required=True, type=str, location='json') | |||
| parser.add_argument('embedding_model_name', required=True, type=str, location='json') | |||
| parser.add_argument("score_threshold", required=True, type=float, location="json") | |||
| parser.add_argument("embedding_provider_name", required=True, type=str, location="json") | |||
| parser.add_argument("embedding_model_name", required=True, type=str, location="json") | |||
| args = parser.parse_args() | |||
| if action == 'enable': | |||
| if action == "enable": | |||
| result = AppAnnotationService.enable_app_annotation(args, app_id) | |||
| elif action == 'disable': | |||
| elif action == "disable": | |||
| result = AppAnnotationService.disable_app_annotation(app_id) | |||
| else: | |||
| raise ValueError('Unsupported annotation reply action') | |||
| raise ValueError("Unsupported annotation reply action") | |||
| return result, 200 | |||
| @@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource): | |||
| annotation_setting_id = str(annotation_setting_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('score_threshold', required=True, type=float, location='json') | |||
| parser.add_argument("score_threshold", required=True, type=float, location="json") | |||
| args = parser.parse_args() | |||
| result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) | |||
| @@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| def get(self, app_id, job_id, action): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| job_id = str(job_id) | |||
| app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) | |||
| app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) | |||
| cache_result = redis_client.get(app_annotation_job_key) | |||
| if cache_result is None: | |||
| raise ValueError("The job is not exist.") | |||
| job_status = cache_result.decode() | |||
| error_msg = '' | |||
| if job_status == 'error': | |||
| app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) | |||
| error_msg = "" | |||
| if job_status == "error": | |||
| app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) | |||
| error_msg = redis_client.get(app_annotation_error_key).decode() | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': job_status, | |||
| 'error_msg': error_msg | |||
| }, 200 | |||
| return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 | |||
| class AnnotationListApi(Resource): | |||
| @@ -109,18 +105,18 @@ class AnnotationListApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| keyword = request.args.get('keyword', default=None, type=str) | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| keyword = request.args.get("keyword", default=None, type=str) | |||
| app_id = str(app_id) | |||
| annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) | |||
| response = { | |||
| 'data': marshal(annotation_list, annotation_fields), | |||
| 'has_more': len(annotation_list) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| "data": marshal(annotation_list, annotation_fields), | |||
| "has_more": len(annotation_list) == limit, | |||
| "limit": limit, | |||
| "total": total, | |||
| "page": page, | |||
| } | |||
| return response, 200 | |||
| @@ -135,9 +131,7 @@ class AnnotationExportApi(Resource): | |||
| app_id = str(app_id) | |||
| annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) | |||
| response = { | |||
| 'data': marshal(annotation_list, annotation_fields) | |||
| } | |||
| response = {"data": marshal(annotation_list, annotation_fields)} | |||
| return response, 200 | |||
| @@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_id): | |||
| if not current_user.is_editor: | |||
| @@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource): | |||
| app_id = str(app_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('question', required=True, type=str, location='json') | |||
| parser.add_argument('answer', required=True, type=str, location='json') | |||
| parser.add_argument("question", required=True, type=str, location="json") | |||
| parser.add_argument("answer", required=True, type=str, location="json") | |||
| args = parser.parse_args() | |||
| annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) | |||
| return annotation | |||
| @@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_id, annotation_id): | |||
| if not current_user.is_editor: | |||
| @@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource): | |||
| app_id = str(app_id) | |||
| annotation_id = str(annotation_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('question', required=True, type=str, location='json') | |||
| parser.add_argument('answer', required=True, type=str, location='json') | |||
| parser.add_argument("question", required=True, type=str, location="json") | |||
| parser.add_argument("answer", required=True, type=str, location="json") | |||
| args = parser.parse_args() | |||
| annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) | |||
| return annotation | |||
| @@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource): | |||
| app_id = str(app_id) | |||
| annotation_id = str(annotation_id) | |||
| AppAnnotationService.delete_app_annotation(app_id, annotation_id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class AnnotationBatchImportApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| def post(self, app_id): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| # get file from request | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| # check file type | |||
| if not file.filename.endswith('.csv'): | |||
| if not file.filename.endswith(".csv"): | |||
| raise ValueError("Invalid file type. Only CSV files are allowed") | |||
| return AppAnnotationService.batch_import_app_annotations(app_id, file) | |||
| @@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| def get(self, app_id, job_id): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| job_id = str(job_id) | |||
| indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) | |||
| indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is None: | |||
| raise ValueError("The job is not exist.") | |||
| job_status = cache_result.decode() | |||
| error_msg = '' | |||
| if job_status == 'error': | |||
| indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) | |||
| error_msg = "" | |||
| if job_status == "error": | |||
| indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) | |||
| error_msg = redis_client.get(indexing_error_msg_key).decode() | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': job_status, | |||
| 'error_msg': error_msg | |||
| }, 200 | |||
| return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 | |||
| class AnnotationHitHistoryListApi(Resource): | |||
| @@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| app_id = str(app_id) | |||
| annotation_id = str(annotation_id) | |||
| annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, | |||
| page, limit) | |||
| annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( | |||
| app_id, annotation_id, page, limit | |||
| ) | |||
| response = { | |||
| 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), | |||
| 'has_more': len(annotation_hit_history_list) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), | |||
| "has_more": len(annotation_hit_history_list) == limit, | |||
| "limit": limit, | |||
| "total": total, | |||
| "page": page, | |||
| } | |||
| return response | |||
| api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>') | |||
| api.add_resource(AnnotationReplyActionStatusApi, | |||
| '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>') | |||
| api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations') | |||
| api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export') | |||
| api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>') | |||
| api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import') | |||
| api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>') | |||
| api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories') | |||
| api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting') | |||
| api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>') | |||
| api.add_resource(AnnotationReplyActionApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>") | |||
| api.add_resource( | |||
| AnnotationReplyActionStatusApi, "/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>" | |||
| ) | |||
| api.add_resource(AnnotationListApi, "/apps/<uuid:app_id>/annotations") | |||
| api.add_resource(AnnotationExportApi, "/apps/<uuid:app_id>/annotations/export") | |||
| api.add_resource(AnnotationUpdateDeleteApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>") | |||
| api.add_resource(AnnotationBatchImportApi, "/apps/<uuid:app_id>/annotations/batch-import") | |||
| api.add_resource(AnnotationBatchImportStatusApi, "/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>") | |||
| api.add_resource(AnnotationHitHistoryListApi, "/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories") | |||
| api.add_resource(AppAnnotationSettingDetailApi, "/apps/<uuid:app_id>/annotation-setting") | |||
| api.add_resource(AppAnnotationSettingUpdateApi, "/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>") | |||
| @@ -18,27 +18,35 @@ from libs.login import login_required | |||
| from services.app_dsl_service import AppDslService | |||
| from services.app_service import AppService | |||
| ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] | |||
| ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] | |||
| class AppListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| """Get app list""" | |||
| def uuid_list(value): | |||
| try: | |||
| return [str(uuid.UUID(v)) for v in value.split(',')] | |||
| return [str(uuid.UUID(v)) for v in value.split(",")] | |||
| except ValueError: | |||
| abort(400, message="Invalid UUID format in tag_ids.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') | |||
| parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) | |||
| parser.add_argument('name', type=str, location='args', required=False) | |||
| parser.add_argument('tag_ids', type=uuid_list, location='args', required=False) | |||
| parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") | |||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument( | |||
| "mode", | |||
| type=str, | |||
| choices=["chat", "workflow", "agent-chat", "channel", "all"], | |||
| default="all", | |||
| location="args", | |||
| required=False, | |||
| ) | |||
| parser.add_argument("name", type=str, location="args", required=False) | |||
| parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) | |||
| args = parser.parse_args() | |||
| @@ -46,7 +54,7 @@ class AppListApi(Resource): | |||
| app_service = AppService() | |||
| app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) | |||
| if not app_pagination: | |||
| return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False} | |||
| return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} | |||
| return marshal(app_pagination, app_pagination_fields) | |||
| @@ -54,23 +62,23 @@ class AppListApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(app_detail_fields) | |||
| @cloud_edition_billing_resource_check('apps') | |||
| @cloud_edition_billing_resource_check("apps") | |||
| def post(self): | |||
| """Create app""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument('description', type=str, location='json') | |||
| parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') | |||
| parser.add_argument('icon_type', type=str, location='json') | |||
| parser.add_argument('icon', type=str, location='json') | |||
| parser.add_argument('icon_background', type=str, location='json') | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| parser.add_argument("description", type=str, location="json") | |||
| parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") | |||
| parser.add_argument("icon_type", type=str, location="json") | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| args = parser.parse_args() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if 'mode' not in args or args['mode'] is None: | |||
| if "mode" not in args or args["mode"] is None: | |||
| raise BadRequest("mode is required") | |||
| app_service = AppService() | |||
| @@ -84,7 +92,7 @@ class AppImportApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(app_detail_fields_with_site) | |||
| @cloud_edition_billing_resource_check('apps') | |||
| @cloud_edition_billing_resource_check("apps") | |||
| def post(self): | |||
| """Import app""" | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| @@ -92,19 +100,16 @@ class AppImportApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('data', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('name', type=str, location='json') | |||
| parser.add_argument('description', type=str, location='json') | |||
| parser.add_argument('icon_type', type=str, location='json') | |||
| parser.add_argument('icon', type=str, location='json') | |||
| parser.add_argument('icon_background', type=str, location='json') | |||
| parser.add_argument("data", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=str, location="json") | |||
| parser.add_argument("description", type=str, location="json") | |||
| parser.add_argument("icon_type", type=str, location="json") | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| args = parser.parse_args() | |||
| app = AppDslService.import_and_create_new_app( | |||
| tenant_id=current_user.current_tenant_id, | |||
| data=args['data'], | |||
| args=args, | |||
| account=current_user | |||
| tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user | |||
| ) | |||
| return app, 201 | |||
| @@ -115,7 +120,7 @@ class AppImportFromUrlApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(app_detail_fields_with_site) | |||
| @cloud_edition_billing_resource_check('apps') | |||
| @cloud_edition_billing_resource_check("apps") | |||
| def post(self): | |||
| """Import app from url""" | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| @@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('url', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('name', type=str, location='json') | |||
| parser.add_argument('description', type=str, location='json') | |||
| parser.add_argument('icon', type=str, location='json') | |||
| parser.add_argument('icon_background', type=str, location='json') | |||
| parser.add_argument("url", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=str, location="json") | |||
| parser.add_argument("description", type=str, location="json") | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| args = parser.parse_args() | |||
| app = AppDslService.import_and_create_new_app_from_url( | |||
| tenant_id=current_user.current_tenant_id, | |||
| url=args['url'], | |||
| args=args, | |||
| account=current_user | |||
| tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user | |||
| ) | |||
| return app, 201 | |||
| class AppApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -165,14 +166,14 @@ class AppApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('description', type=str, location='json') | |||
| parser.add_argument('icon_type', type=str, location='json') | |||
| parser.add_argument('icon', type=str, location='json') | |||
| parser.add_argument('icon_background', type=str, location='json') | |||
| parser.add_argument('max_active_requests', type=int, location='json') | |||
| parser.add_argument("name", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("description", type=str, location="json") | |||
| parser.add_argument("icon_type", type=str, location="json") | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| parser.add_argument("max_active_requests", type=int, location="json") | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| @@ -193,7 +194,7 @@ class AppApi(Resource): | |||
| app_service = AppService() | |||
| app_service.delete_app(app_model) | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class AppCopyApi(Resource): | |||
| @@ -209,19 +210,16 @@ class AppCopyApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, location='json') | |||
| parser.add_argument('description', type=str, location='json') | |||
| parser.add_argument('icon_type', type=str, location='json') | |||
| parser.add_argument('icon', type=str, location='json') | |||
| parser.add_argument('icon_background', type=str, location='json') | |||
| parser.add_argument("name", type=str, location="json") | |||
| parser.add_argument("description", type=str, location="json") | |||
| parser.add_argument("icon_type", type=str, location="json") | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| args = parser.parse_args() | |||
| data = AppDslService.export_dsl(app_model=app_model, include_secret=True) | |||
| app = AppDslService.import_and_create_new_app( | |||
| tenant_id=current_user.current_tenant_id, | |||
| data=data, | |||
| args=args, | |||
| account=current_user | |||
| tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user | |||
| ) | |||
| return app, 201 | |||
| @@ -240,12 +238,10 @@ class AppExportApi(Resource): | |||
| # Add include_secret params | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') | |||
| parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") | |||
| args = parser.parse_args() | |||
| return { | |||
| "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) | |||
| } | |||
| return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} | |||
| class AppNameApi(Resource): | |||
| @@ -258,13 +254,13 @@ class AppNameApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_name(app_model, args.get('name')) | |||
| app_model = app_service.update_app_name(app_model, args.get("name")) | |||
| return app_model | |||
| @@ -279,14 +275,14 @@ class AppIconApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('icon', type=str, location='json') | |||
| parser.add_argument('icon_background', type=str, location='json') | |||
| parser.add_argument("icon", type=str, location="json") | |||
| parser.add_argument("icon_background", type=str, location="json") | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) | |||
| app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) | |||
| return app_model | |||
| @@ -301,13 +297,13 @@ class AppSiteStatus(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('enable_site', type=bool, required=True, location='json') | |||
| parser.add_argument("enable_site", type=bool, required=True, location="json") | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) | |||
| app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) | |||
| return app_model | |||
| @@ -322,13 +318,13 @@ class AppApiStatus(Resource): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('enable_api', type=bool, required=True, location='json') | |||
| parser.add_argument("enable_api", type=bool, required=True, location="json") | |||
| args = parser.parse_args() | |||
| app_service = AppService() | |||
| app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) | |||
| app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) | |||
| return app_model | |||
| @@ -339,9 +335,7 @@ class AppTraceApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, app_id): | |||
| """Get app trace""" | |||
| app_trace_config = OpsTraceManager.get_app_tracing_config( | |||
| app_id=app_id | |||
| ) | |||
| app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) | |||
| return app_trace_config | |||
| @@ -353,27 +347,27 @@ class AppTraceApi(Resource): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('enabled', type=bool, required=True, location='json') | |||
| parser.add_argument('tracing_provider', type=str, required=True, location='json') | |||
| parser.add_argument("enabled", type=bool, required=True, location="json") | |||
| parser.add_argument("tracing_provider", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| OpsTraceManager.update_app_tracing_config( | |||
| app_id=app_id, | |||
| enabled=args['enabled'], | |||
| tracing_provider=args['tracing_provider'], | |||
| enabled=args["enabled"], | |||
| tracing_provider=args["tracing_provider"], | |||
| ) | |||
| return {"result": "success"} | |||
| api.add_resource(AppListApi, '/apps') | |||
| api.add_resource(AppImportApi, '/apps/import') | |||
| api.add_resource(AppImportFromUrlApi, '/apps/import/url') | |||
| api.add_resource(AppApi, '/apps/<uuid:app_id>') | |||
| api.add_resource(AppCopyApi, '/apps/<uuid:app_id>/copy') | |||
| api.add_resource(AppExportApi, '/apps/<uuid:app_id>/export') | |||
| api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name') | |||
| api.add_resource(AppIconApi, '/apps/<uuid:app_id>/icon') | |||
| api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable') | |||
| api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable') | |||
| api.add_resource(AppTraceApi, '/apps/<uuid:app_id>/trace') | |||
| api.add_resource(AppListApi, "/apps") | |||
| api.add_resource(AppImportApi, "/apps/import") | |||
| api.add_resource(AppImportFromUrlApi, "/apps/import/url") | |||
| api.add_resource(AppApi, "/apps/<uuid:app_id>") | |||
| api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy") | |||
| api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export") | |||
| api.add_resource(AppNameApi, "/apps/<uuid:app_id>/name") | |||
| api.add_resource(AppIconApi, "/apps/<uuid:app_id>/icon") | |||
| api.add_resource(AppSiteStatus, "/apps/<uuid:app_id>/site-enable") | |||
| api.add_resource(AppApiStatus, "/apps/<uuid:app_id>/api-enable") | |||
| api.add_resource(AppTraceApi, "/apps/<uuid:app_id>/trace") | |||
| @@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource): | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) | |||
| def post(self, app_model): | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| try: | |||
| response = AudioService.transcript_asr( | |||
| @@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource): | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', type=str, location='json') | |||
| parser.add_argument('text', type=str, location='json') | |||
| parser.add_argument('voice', type=str, location='json') | |||
| parser.add_argument('streaming', type=bool, location='json') | |||
| parser.add_argument("message_id", type=str, location="json") | |||
| parser.add_argument("text", type=str, location="json") | |||
| parser.add_argument("voice", type=str, location="json") | |||
| parser.add_argument("streaming", type=bool, location="json") | |||
| args = parser.parse_args() | |||
| message_id = args.get('message_id', None) | |||
| text = args.get('text', None) | |||
| if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict): | |||
| text_to_speech = app_model.workflow.features_dict.get('text_to_speech') | |||
| voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| text_to_speech = app_model.workflow.features_dict.get("text_to_speech") | |||
| voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") | |||
| else: | |||
| try: | |||
| voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( | |||
| 'voice') | |||
| voice = ( | |||
| args.get("voice") | |||
| if args.get("voice") | |||
| else app_model.app_model_config.text_to_speech_dict.get("voice") | |||
| ) | |||
| except Exception: | |||
| voice = None | |||
| response = AudioService.transcript_tts( | |||
| app_model=app_model, | |||
| text=text, | |||
| message_id=message_id, | |||
| voice=voice | |||
| ) | |||
| response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| @@ -145,12 +145,12 @@ class TextModesApi(Resource): | |||
| def get(self, app_model): | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('language', type=str, required=True, location='args') | |||
| parser.add_argument("language", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| response = AudioService.transcript_tts_voices( | |||
| tenant_id=app_model.tenant_id, | |||
| language=args['language'], | |||
| language=args["language"], | |||
| ) | |||
| return response | |||
| @@ -179,6 +179,6 @@ class TextModesApi(Resource): | |||
| raise InternalServerError() | |||
| api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text') | |||
| api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio') | |||
| api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices') | |||
| api.add_resource(ChatMessageAudioApi, "/apps/<uuid:app_id>/audio-to-text") | |||
| api.add_resource(ChatMessageTextApi, "/apps/<uuid:app_id>/text-to-audio") | |||
| api.add_resource(TextModesApi, "/apps/<uuid:app_id>/text-to-audio/voices") | |||
| @@ -35,33 +35,28 @@ from services.app_generate_service import AppGenerateService | |||
| # define completion message api for user | |||
| class CompletionMessageApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model(mode=AppMode.COMPLETION) | |||
| def post(self, app_model): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('model_config', type=dict, required=True, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, location="json", default="") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("model_config", type=dict, required=True, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] != 'blocking' | |||
| args['auto_generate_name'] = False | |||
| streaming = args["response_mode"] != "blocking" | |||
| args["auto_generate_name"] = False | |||
| account = flask_login.current_user | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=account, | |||
| args=args, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=streaming | |||
| app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -97,7 +92,7 @@ class CompletionMessageStopApi(Resource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class ChatMessageApi(Resource): | |||
| @@ -107,27 +102,23 @@ class ChatMessageApi(Resource): | |||
| @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) | |||
| def post(self, app_model): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('model_config', type=dict, required=True, location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("model_config", type=dict, required=True, location="json") | |||
| parser.add_argument("conversation_id", type=uuid_value, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] != 'blocking' | |||
| args['auto_generate_name'] = False | |||
| streaming = args["response_mode"] != "blocking" | |||
| args["auto_generate_name"] = False | |||
| account = flask_login.current_user | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=account, | |||
| args=args, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=streaming | |||
| app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -163,10 +154,10 @@ class ChatMessageStopApi(Resource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages') | |||
| api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop') | |||
| api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages') | |||
| api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop') | |||
| api.add_resource(CompletionMessageApi, "/apps/<uuid:app_id>/completion-messages") | |||
| api.add_resource(CompletionMessageStopApi, "/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop") | |||
| api.add_resource(ChatMessageApi, "/apps/<uuid:app_id>/chat-messages") | |||
| api.add_resource(ChatMessageStopApi, "/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop") | |||
| @@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat | |||
| class CompletionConversationApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -36,24 +35,23 @@ class CompletionConversationApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('keyword', type=str, location='args') | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('annotation_status', type=str, | |||
| choices=['annotated', 'not_annotated', 'all'], default='all', location='args') | |||
| parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') | |||
| parser.add_argument("keyword", type=str, location="args") | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument( | |||
| "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" | |||
| ) | |||
| parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| args = parser.parse_args() | |||
| query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') | |||
| query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") | |||
| if args['keyword']: | |||
| query = query.join( | |||
| Message, Message.conversation_id == Conversation.id | |||
| ).filter( | |||
| if args["keyword"]: | |||
| query = query.join(Message, Message.conversation_id == Conversation.id).filter( | |||
| or_( | |||
| Message.query.ilike('%{}%'.format(args['keyword'])), | |||
| Message.answer.ilike('%{}%'.format(args['keyword'])) | |||
| Message.query.ilike("%{}%".format(args["keyword"])), | |||
| Message.answer.ilike("%{}%".format(args["keyword"])), | |||
| ) | |||
| ) | |||
| @@ -61,8 +59,8 @@ class CompletionConversationApi(Resource): | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| @@ -70,8 +68,8 @@ class CompletionConversationApi(Resource): | |||
| query = query.where(Conversation.created_at >= start_datetime_utc) | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=59) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| @@ -79,29 +77,25 @@ class CompletionConversationApi(Resource): | |||
| query = query.where(Conversation.created_at < end_datetime_utc) | |||
| if args['annotation_status'] == "annotated": | |||
| if args["annotation_status"] == "annotated": | |||
| query = query.options(joinedload(Conversation.message_annotations)).join( | |||
| MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id | |||
| ) | |||
| elif args['annotation_status'] == "not_annotated": | |||
| query = query.outerjoin( | |||
| MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id | |||
| ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) | |||
| elif args["annotation_status"] == "not_annotated": | |||
| query = ( | |||
| query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) | |||
| .group_by(Conversation.id) | |||
| .having(func.count(MessageAnnotation.id) == 0) | |||
| ) | |||
| query = query.order_by(Conversation.created_at.desc()) | |||
| conversations = db.paginate( | |||
| query, | |||
| page=args['page'], | |||
| per_page=args['limit'], | |||
| error_out=False | |||
| ) | |||
| conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) | |||
| return conversations | |||
| class CompletionConversationDetailApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource): | |||
| raise Forbidden() | |||
| conversation_id = str(conversation_id) | |||
| conversation = db.session.query(Conversation) \ | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource): | |||
| conversation.is_deleted = True | |||
| db.session.commit() | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class ChatConversationApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -146,22 +142,28 @@ class ChatConversationApi(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('keyword', type=str, location='args') | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('annotation_status', type=str, | |||
| choices=['annotated', 'not_annotated', 'all'], default='all', location='args') | |||
| parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') | |||
| parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], | |||
| required=False, default='-updated_at', location='args') | |||
| parser.add_argument("keyword", type=str, location="args") | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument( | |||
| "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" | |||
| ) | |||
| parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") | |||
| parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument( | |||
| "sort_by", | |||
| type=str, | |||
| choices=["created_at", "-created_at", "updated_at", "-updated_at"], | |||
| required=False, | |||
| default="-updated_at", | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| subquery = ( | |||
| db.session.query( | |||
| Conversation.id.label('conversation_id'), | |||
| EndUser.session_id.label('from_end_user_session_id') | |||
| Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") | |||
| ) | |||
| .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) | |||
| .subquery() | |||
| @@ -169,28 +171,31 @@ class ChatConversationApi(Resource): | |||
| query = db.select(Conversation).where(Conversation.app_id == app_model.id) | |||
| if args['keyword']: | |||
| keyword_filter = '%{}%'.format(args['keyword']) | |||
| query = query.join( | |||
| Message, Message.conversation_id == Conversation.id, | |||
| ).join( | |||
| subquery, subquery.c.conversation_id == Conversation.id | |||
| ).filter( | |||
| or_( | |||
| Message.query.ilike(keyword_filter), | |||
| Message.answer.ilike(keyword_filter), | |||
| Conversation.name.ilike(keyword_filter), | |||
| Conversation.introduction.ilike(keyword_filter), | |||
| subquery.c.from_end_user_session_id.ilike(keyword_filter) | |||
| ), | |||
| if args["keyword"]: | |||
| keyword_filter = "%{}%".format(args["keyword"]) | |||
| query = ( | |||
| query.join( | |||
| Message, | |||
| Message.conversation_id == Conversation.id, | |||
| ) | |||
| .join(subquery, subquery.c.conversation_id == Conversation.id) | |||
| .filter( | |||
| or_( | |||
| Message.query.ilike(keyword_filter), | |||
| Message.answer.ilike(keyword_filter), | |||
| Conversation.name.ilike(keyword_filter), | |||
| Conversation.introduction.ilike(keyword_filter), | |||
| subquery.c.from_end_user_session_id.ilike(keyword_filter), | |||
| ), | |||
| ) | |||
| ) | |||
| account = current_user | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| @@ -198,8 +203,8 @@ class ChatConversationApi(Resource): | |||
| query = query.where(Conversation.created_at >= start_datetime_utc) | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=59) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| @@ -207,50 +212,46 @@ class ChatConversationApi(Resource): | |||
| query = query.where(Conversation.created_at < end_datetime_utc) | |||
| if args['annotation_status'] == "annotated": | |||
| if args["annotation_status"] == "annotated": | |||
| query = query.options(joinedload(Conversation.message_annotations)).join( | |||
| MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id | |||
| ) | |||
| elif args['annotation_status'] == "not_annotated": | |||
| query = query.outerjoin( | |||
| MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id | |||
| ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) | |||
| elif args["annotation_status"] == "not_annotated": | |||
| query = ( | |||
| query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) | |||
| .group_by(Conversation.id) | |||
| .having(func.count(MessageAnnotation.id) == 0) | |||
| ) | |||
| if args['message_count_gte'] and args['message_count_gte'] >= 1: | |||
| if args["message_count_gte"] and args["message_count_gte"] >= 1: | |||
| query = ( | |||
| query.options(joinedload(Conversation.messages)) | |||
| .join(Message, Message.conversation_id == Conversation.id) | |||
| .group_by(Conversation.id) | |||
| .having(func.count(Message.id) >= args['message_count_gte']) | |||
| .having(func.count(Message.id) >= args["message_count_gte"]) | |||
| ) | |||
| if app_model.mode == AppMode.ADVANCED_CHAT.value: | |||
| query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) | |||
| match args['sort_by']: | |||
| case 'created_at': | |||
| match args["sort_by"]: | |||
| case "created_at": | |||
| query = query.order_by(Conversation.created_at.asc()) | |||
| case '-created_at': | |||
| case "-created_at": | |||
| query = query.order_by(Conversation.created_at.desc()) | |||
| case 'updated_at': | |||
| case "updated_at": | |||
| query = query.order_by(Conversation.updated_at.asc()) | |||
| case '-updated_at': | |||
| case "-updated_at": | |||
| query = query.order_by(Conversation.updated_at.desc()) | |||
| case _: | |||
| query = query.order_by(Conversation.created_at.desc()) | |||
| conversations = db.paginate( | |||
| query, | |||
| page=args['page'], | |||
| per_page=args['limit'], | |||
| error_out=False | |||
| ) | |||
| conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) | |||
| return conversations | |||
| class ChatConversationDetailApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -272,8 +273,11 @@ class ChatConversationDetailApi(Resource): | |||
| raise Forbidden() | |||
| conversation_id = str(conversation_id) | |||
| conversation = db.session.query(Conversation) \ | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -281,18 +285,21 @@ class ChatConversationDetailApi(Resource): | |||
| conversation.is_deleted = True | |||
| db.session.commit() | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations') | |||
| api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>') | |||
| api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations') | |||
| api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>') | |||
| api.add_resource(CompletionConversationApi, "/apps/<uuid:app_id>/completion-conversations") | |||
| api.add_resource(CompletionConversationDetailApi, "/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>") | |||
| api.add_resource(ChatConversationApi, "/apps/<uuid:app_id>/chat-conversations") | |||
| api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>") | |||
| def _get_conversation(app_model, conversation_id): | |||
| conversation = db.session.query(Conversation) \ | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise NotFound("Conversation Not Exists.") | |||
| @@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource): | |||
| @marshal_with(paginated_conversation_variable_fields) | |||
| def get(self, app_model): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('conversation_id', type=str, location='args') | |||
| parser.add_argument("conversation_id", type=str, location="args") | |||
| args = parser.parse_args() | |||
| stmt = ( | |||
| @@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource): | |||
| .where(ConversationVariable.app_id == app_model.id) | |||
| .order_by(ConversationVariable.created_at) | |||
| ) | |||
| if args['conversation_id']: | |||
| stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) | |||
| if args["conversation_id"]: | |||
| stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) | |||
| else: | |||
| raise ValueError('conversation_id is required') | |||
| raise ValueError("conversation_id is required") | |||
| # NOTE: This is a temporary solution to avoid performance issues. | |||
| page = 1 | |||
| @@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource): | |||
| rows = session.scalars(stmt).all() | |||
| return { | |||
| 'page': page, | |||
| 'limit': page_size, | |||
| 'total': len(rows), | |||
| 'has_more': False, | |||
| 'data': [ | |||
| "page": page, | |||
| "limit": page_size, | |||
| "total": len(rows), | |||
| "has_more": False, | |||
| "data": [ | |||
| { | |||
| 'created_at': row.created_at, | |||
| 'updated_at': row.updated_at, | |||
| "created_at": row.created_at, | |||
| "updated_at": row.updated_at, | |||
| **row.to_variable().model_dump(), | |||
| } | |||
| for row in rows | |||
| @@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource): | |||
| } | |||
| api.add_resource(ConversationVariablesApi, '/apps/<uuid:app_id>/conversation-variables') | |||
| api.add_resource(ConversationVariablesApi, "/apps/<uuid:app_id>/conversation-variables") | |||
| @@ -2,116 +2,120 @@ from libs.exception import BaseHTTPException | |||
| class AppNotFoundError(BaseHTTPException): | |||
| error_code = 'app_not_found' | |||
| error_code = "app_not_found" | |||
| description = "App not found." | |||
| code = 404 | |||
| class ProviderNotInitializeError(BaseHTTPException): | |||
| error_code = 'provider_not_initialize' | |||
| description = "No valid model provider credentials found. " \ | |||
| "Please go to Settings -> Model Provider to complete your provider credentials." | |||
| error_code = "provider_not_initialize" | |||
| description = ( | |||
| "No valid model provider credentials found. " | |||
| "Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| code = 400 | |||
| class ProviderQuotaExceededError(BaseHTTPException): | |||
| error_code = 'provider_quota_exceeded' | |||
| description = "Your quota for Dify Hosted Model Provider has been exhausted. " \ | |||
| "Please go to Settings -> Model Provider to complete your own provider credentials." | |||
| error_code = "provider_quota_exceeded" | |||
| description = ( | |||
| "Your quota for Dify Hosted Model Provider has been exhausted. " | |||
| "Please go to Settings -> Model Provider to complete your own provider credentials." | |||
| ) | |||
| code = 400 | |||
| class ProviderModelCurrentlyNotSupportError(BaseHTTPException): | |||
| error_code = 'model_currently_not_support' | |||
| error_code = "model_currently_not_support" | |||
| description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." | |||
| code = 400 | |||
| class ConversationCompletedError(BaseHTTPException): | |||
| error_code = 'conversation_completed' | |||
| error_code = "conversation_completed" | |||
| description = "The conversation has ended. Please start a new conversation." | |||
| code = 400 | |||
| class AppUnavailableError(BaseHTTPException): | |||
| error_code = 'app_unavailable' | |||
| error_code = "app_unavailable" | |||
| description = "App unavailable, please check your app configurations." | |||
| code = 400 | |||
| class CompletionRequestError(BaseHTTPException): | |||
| error_code = 'completion_request_error' | |||
| error_code = "completion_request_error" | |||
| description = "Completion request failed." | |||
| code = 400 | |||
| class AppMoreLikeThisDisabledError(BaseHTTPException): | |||
| error_code = 'app_more_like_this_disabled' | |||
| error_code = "app_more_like_this_disabled" | |||
| description = "The 'More like this' feature is disabled. Please refresh your page." | |||
| code = 403 | |||
| class NoAudioUploadedError(BaseHTTPException): | |||
| error_code = 'no_audio_uploaded' | |||
| error_code = "no_audio_uploaded" | |||
| description = "Please upload your audio." | |||
| code = 400 | |||
| class AudioTooLargeError(BaseHTTPException): | |||
| error_code = 'audio_too_large' | |||
| error_code = "audio_too_large" | |||
| description = "Audio size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedAudioTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_audio_type' | |||
| error_code = "unsupported_audio_type" | |||
| description = "Audio type not allowed." | |||
| code = 415 | |||
| class ProviderNotSupportSpeechToTextError(BaseHTTPException): | |||
| error_code = 'provider_not_support_speech_to_text' | |||
| error_code = "provider_not_support_speech_to_text" | |||
| description = "Provider not support speech to text." | |||
| code = 400 | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| error_code = "no_file_uploaded" | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| error_code = "too_many_files" | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| class DraftWorkflowNotExist(BaseHTTPException): | |||
| error_code = 'draft_workflow_not_exist' | |||
| error_code = "draft_workflow_not_exist" | |||
| description = "Draft workflow need to be initialized." | |||
| code = 400 | |||
| class DraftWorkflowNotSync(BaseHTTPException): | |||
| error_code = 'draft_workflow_not_sync' | |||
| error_code = "draft_workflow_not_sync" | |||
| description = "Workflow graph might have been modified, please refresh and resubmit." | |||
| code = 400 | |||
| class TracingConfigNotExist(BaseHTTPException): | |||
| error_code = 'trace_config_not_exist' | |||
| error_code = "trace_config_not_exist" | |||
| description = "Trace config not exist." | |||
| code = 400 | |||
| class TracingConfigIsExist(BaseHTTPException): | |||
| error_code = 'trace_config_is_exist' | |||
| error_code = "trace_config_is_exist" | |||
| description = "Trace config is exist." | |||
| code = 400 | |||
| class TracingConfigCheckError(BaseHTTPException): | |||
| error_code = 'trace_config_check_error' | |||
| error_code = "trace_config_check_error" | |||
| description = "Invalid Credentials." | |||
| code = 400 | |||
| @@ -24,21 +24,21 @@ class RuleGenerateApi(Resource): | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('instruction', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('no_variable', type=bool, required=True, default=False, location='json') | |||
| parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") | |||
| args = parser.parse_args() | |||
| account = current_user | |||
| PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) | |||
| PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) | |||
| try: | |||
| rules = LLMGenerator.generate_rule_config( | |||
| tenant_id=account.current_tenant_id, | |||
| instruction=args['instruction'], | |||
| model_config=args['model_config'], | |||
| no_variable=args['no_variable'], | |||
| rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS | |||
| instruction=args["instruction"], | |||
| model_config=args["model_config"], | |||
| no_variable=args["no_variable"], | |||
| rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -52,4 +52,4 @@ class RuleGenerateApi(Resource): | |||
| return rules | |||
| api.add_resource(RuleGenerateApi, '/rule-generate') | |||
| api.add_resource(RuleGenerateApi, "/rule-generate") | |||
| @@ -33,9 +33,9 @@ from services.message_service import MessageService | |||
| class ChatMessageListApi(Resource): | |||
| message_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(message_detail_fields)) | |||
| "limit": fields.Integer, | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(message_detail_fields)), | |||
| } | |||
| @setup_required | |||
| @@ -45,55 +45,69 @@ class ChatMessageListApi(Resource): | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| def get(self, app_model): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') | |||
| parser.add_argument('first_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") | |||
| parser.add_argument("first_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| conversation = db.session.query(Conversation).filter( | |||
| Conversation.id == args['conversation_id'], | |||
| Conversation.app_id == app_model.id | |||
| ).first() | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise NotFound("Conversation Not Exists.") | |||
| if args['first_id']: | |||
| first_message = db.session.query(Message) \ | |||
| .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first() | |||
| if args["first_id"]: | |||
| first_message = ( | |||
| db.session.query(Message) | |||
| .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) | |||
| .first() | |||
| ) | |||
| if not first_message: | |||
| raise NotFound("First message not found") | |||
| history_messages = db.session.query(Message).filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < first_message.created_at, | |||
| Message.id != first_message.id | |||
| ) \ | |||
| .order_by(Message.created_at.desc()).limit(args['limit']).all() | |||
| history_messages = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < first_message.created_at, | |||
| Message.id != first_message.id, | |||
| ) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(args["limit"]) | |||
| .all() | |||
| ) | |||
| else: | |||
| history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ | |||
| .order_by(Message.created_at.desc()).limit(args['limit']).all() | |||
| history_messages = ( | |||
| db.session.query(Message) | |||
| .filter(Message.conversation_id == conversation.id) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(args["limit"]) | |||
| .all() | |||
| ) | |||
| has_more = False | |||
| if len(history_messages) == args['limit']: | |||
| if len(history_messages) == args["limit"]: | |||
| current_page_first_message = history_messages[-1] | |||
| rest_count = db.session.query(Message).filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id | |||
| ).count() | |||
| rest_count = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| if rest_count > 0: | |||
| has_more = True | |||
| history_messages = list(reversed(history_messages)) | |||
| return InfiniteScrollPagination( | |||
| data=history_messages, | |||
| limit=args['limit'], | |||
| has_more=has_more | |||
| ) | |||
| return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) | |||
| class MessageFeedbackApi(Resource): | |||
| @@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource): | |||
| @get_app_model | |||
| def post(self, app_model): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', required=True, type=uuid_value, location='json') | |||
| parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') | |||
| parser.add_argument("message_id", required=True, type=uuid_value, location="json") | |||
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |||
| args = parser.parse_args() | |||
| message_id = str(args['message_id']) | |||
| message_id = str(args["message_id"]) | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id | |||
| ).first() | |||
| message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| feedback = message.admin_feedback | |||
| if not args['rating'] and feedback: | |||
| if not args["rating"] and feedback: | |||
| db.session.delete(feedback) | |||
| elif args['rating'] and feedback: | |||
| feedback.rating = args['rating'] | |||
| elif not args['rating'] and not feedback: | |||
| raise ValueError('rating cannot be None when feedback not exists') | |||
| elif args["rating"] and feedback: | |||
| feedback.rating = args["rating"] | |||
| elif not args["rating"] and not feedback: | |||
| raise ValueError("rating cannot be None when feedback not exists") | |||
| else: | |||
| feedback = MessageFeedback( | |||
| app_id=app_model.id, | |||
| conversation_id=message.conversation_id, | |||
| message_id=message.id, | |||
| rating=args['rating'], | |||
| from_source='admin', | |||
| from_account_id=current_user.id | |||
| rating=args["rating"], | |||
| from_source="admin", | |||
| from_account_id=current_user.id, | |||
| ) | |||
| db.session.add(feedback) | |||
| db.session.commit() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class MessageAnnotationApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @cloud_edition_billing_resource_check("annotation") | |||
| @get_app_model | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_model): | |||
| @@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', required=False, type=uuid_value, location='json') | |||
| parser.add_argument('question', required=True, type=str, location='json') | |||
| parser.add_argument('answer', required=True, type=str, location='json') | |||
| parser.add_argument('annotation_reply', required=False, type=dict, location='json') | |||
| parser.add_argument("message_id", required=False, type=uuid_value, location="json") | |||
| parser.add_argument("question", required=True, type=str, location="json") | |||
| parser.add_argument("answer", required=True, type=str, location="json") | |||
| parser.add_argument("annotation_reply", required=False, type=dict, location="json") | |||
| args = parser.parse_args() | |||
| annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) | |||
| @@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource): | |||
| @account_initialization_required | |||
| @get_app_model | |||
| def get(self, app_model): | |||
| count = db.session.query(MessageAnnotation).filter( | |||
| MessageAnnotation.app_id == app_model.id | |||
| ).count() | |||
| count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() | |||
| return {'count': count} | |||
| return {"count": count} | |||
| class MessageSuggestedQuestionApi(Resource): | |||
| @@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource): | |||
| try: | |||
| questions = MessageService.get_suggested_questions_after_answer( | |||
| app_model=app_model, | |||
| message_id=message_id, | |||
| user=current_user, | |||
| invoke_from=InvokeFrom.DEBUGGER | |||
| app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER | |||
| ) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message not found") | |||
| @@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {'data': questions} | |||
| return {"data": questions} | |||
| class MessageApi(Resource): | |||
| @@ -221,10 +227,7 @@ class MessageApi(Resource): | |||
| def get(self, app_model, message_id): | |||
| message_id = str(message_id) | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id | |||
| ).first() | |||
| message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -232,9 +235,9 @@ class MessageApi(Resource): | |||
| return message | |||
| api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions') | |||
| api.add_resource(ChatMessageListApi, '/apps/<uuid:app_id>/chat-messages', endpoint='console_chat_messages') | |||
| api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks') | |||
| api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations') | |||
| api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count') | |||
| api.add_resource(MessageApi, '/apps/<uuid:app_id>/messages/<uuid:message_id>', endpoint='console_message') | |||
| api.add_resource(MessageSuggestedQuestionApi, "/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions") | |||
| api.add_resource(ChatMessageListApi, "/apps/<uuid:app_id>/chat-messages", endpoint="console_chat_messages") | |||
| api.add_resource(MessageFeedbackApi, "/apps/<uuid:app_id>/feedbacks") | |||
| api.add_resource(MessageAnnotationApi, "/apps/<uuid:app_id>/annotations") | |||
| api.add_resource(MessageAnnotationCountApi, "/apps/<uuid:app_id>/annotations/count") | |||
| api.add_resource(MessageApi, "/apps/<uuid:app_id>/messages/<uuid:message_id>", endpoint="console_message") | |||
| @@ -19,19 +19,15 @@ from services.app_model_config_service import AppModelConfigService | |||
| class ModelConfigResource(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) | |||
| def post(self, app_model): | |||
| """Modify app model config""" | |||
| # validate config | |||
| model_configuration = AppModelConfigService.validate_configuration( | |||
| tenant_id=current_user.current_tenant_id, | |||
| config=request.json, | |||
| app_mode=AppMode.value_of(app_model.mode) | |||
| tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) | |||
| ) | |||
| new_app_model_config = AppModelConfig( | |||
| @@ -41,15 +37,15 @@ class ModelConfigResource(Resource): | |||
| if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: | |||
| # get original app model config | |||
| original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( | |||
| AppModelConfig.id == app_model.app_model_config_id | |||
| ).first() | |||
| original_app_model_config: AppModelConfig = ( | |||
| db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() | |||
| ) | |||
| agent_mode = original_app_model_config.agent_mode_dict | |||
| # decrypt agent tool parameters if it's secret-input | |||
| parameter_map = {} | |||
| masked_parameter_map = {} | |||
| tool_map = {} | |||
| for tool in agent_mode.get('tools') or []: | |||
| for tool in agent_mode.get("tools") or []: | |||
| if not isinstance(tool, dict) or len(tool.keys()) <= 3: | |||
| continue | |||
| @@ -66,7 +62,7 @@ class ModelConfigResource(Resource): | |||
| tool_runtime=tool_runtime, | |||
| provider_name=agent_tool_entity.provider_id, | |||
| provider_type=agent_tool_entity.provider_type, | |||
| identity_id=f'AGENT.{app_model.id}' | |||
| identity_id=f"AGENT.{app_model.id}", | |||
| ) | |||
| except Exception as e: | |||
| continue | |||
| @@ -79,18 +75,18 @@ class ModelConfigResource(Resource): | |||
| parameters = {} | |||
| masked_parameter = {} | |||
| key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' | |||
| key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" | |||
| masked_parameter_map[key] = masked_parameter | |||
| parameter_map[key] = parameters | |||
| tool_map[key] = tool_runtime | |||
| # encrypt agent tool parameters if it's secret-input | |||
| agent_mode = new_app_model_config.agent_mode_dict | |||
| for tool in agent_mode.get('tools') or []: | |||
| for tool in agent_mode.get("tools") or []: | |||
| agent_tool_entity = AgentToolEntity(**tool) | |||
| # get tool | |||
| key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' | |||
| key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" | |||
| if key in tool_map: | |||
| tool_runtime = tool_map[key] | |||
| else: | |||
| @@ -108,7 +104,7 @@ class ModelConfigResource(Resource): | |||
| tool_runtime=tool_runtime, | |||
| provider_name=agent_tool_entity.provider_id, | |||
| provider_type=agent_tool_entity.provider_type, | |||
| identity_id=f'AGENT.{app_model.id}' | |||
| identity_id=f"AGENT.{app_model.id}", | |||
| ) | |||
| manager.delete_tool_parameters_cache() | |||
| @@ -116,15 +112,17 @@ class ModelConfigResource(Resource): | |||
| if agent_tool_entity.tool_parameters: | |||
| if key not in masked_parameter_map: | |||
| continue | |||
| for masked_key, masked_value in masked_parameter_map[key].items(): | |||
| if masked_key in agent_tool_entity.tool_parameters and \ | |||
| agent_tool_entity.tool_parameters[masked_key] == masked_value: | |||
| if ( | |||
| masked_key in agent_tool_entity.tool_parameters | |||
| and agent_tool_entity.tool_parameters[masked_key] == masked_value | |||
| ): | |||
| agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) | |||
| # encrypt parameters | |||
| if agent_tool_entity.tool_parameters: | |||
| tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) | |||
| tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) | |||
| # update app model config | |||
| new_app_model_config.agent_mode = json.dumps(agent_mode) | |||
| @@ -135,12 +133,9 @@ class ModelConfigResource(Resource): | |||
| app_model.app_model_config_id = new_app_model_config.id | |||
| db.session.commit() | |||
| app_model_config_was_updated.send( | |||
| app_model, | |||
| app_model_config=new_app_model_config | |||
| ) | |||
| app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config') | |||
| api.add_resource(ModelConfigResource, "/apps/<uuid:app_id>/model-config") | |||
| @@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, app_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tracing_provider', type=str, required=True, location='args') | |||
| parser.add_argument("tracing_provider", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| trace_config = OpsService.get_tracing_app_config( | |||
| app_id=app_id, tracing_provider=args['tracing_provider'] | |||
| ) | |||
| trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) | |||
| if not trace_config: | |||
| return {"has_not_configured": True} | |||
| return trace_config | |||
| @@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource): | |||
| def post(self, app_id): | |||
| """Create a new trace app configuration""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tracing_provider', type=str, required=True, location='json') | |||
| parser.add_argument('tracing_config', type=dict, required=True, location='json') | |||
| parser.add_argument("tracing_provider", type=str, required=True, location="json") | |||
| parser.add_argument("tracing_config", type=dict, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| result = OpsService.create_tracing_app_config( | |||
| app_id=app_id, | |||
| tracing_provider=args['tracing_provider'], | |||
| tracing_config=args['tracing_config'] | |||
| app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] | |||
| ) | |||
| if not result: | |||
| raise TracingConfigIsExist() | |||
| if result.get('error'): | |||
| if result.get("error"): | |||
| raise TracingConfigCheckError() | |||
| return result | |||
| except Exception as e: | |||
| @@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource): | |||
| def patch(self, app_id): | |||
| """Update an existing trace app configuration""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tracing_provider', type=str, required=True, location='json') | |||
| parser.add_argument('tracing_config', type=dict, required=True, location='json') | |||
| parser.add_argument("tracing_provider", type=str, required=True, location="json") | |||
| parser.add_argument("tracing_config", type=dict, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| result = OpsService.update_tracing_app_config( | |||
| app_id=app_id, | |||
| tracing_provider=args['tracing_provider'], | |||
| tracing_config=args['tracing_config'] | |||
| app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] | |||
| ) | |||
| if not result: | |||
| raise TracingConfigNotExist() | |||
| @@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource): | |||
| def delete(self, app_id): | |||
| """Delete an existing trace app configuration""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tracing_provider', type=str, required=True, location='args') | |||
| parser.add_argument("tracing_provider", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| result = OpsService.delete_tracing_app_config( | |||
| app_id=app_id, | |||
| tracing_provider=args['tracing_provider'] | |||
| ) | |||
| result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) | |||
| if not result: | |||
| raise TracingConfigNotExist() | |||
| return {"result": "success"} | |||
| @@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource): | |||
| raise e | |||
| api.add_resource(TraceAppConfigApi, '/apps/<uuid:app_id>/trace-config') | |||
| api.add_resource(TraceAppConfigApi, "/apps/<uuid:app_id>/trace-config") | |||
| @@ -15,23 +15,23 @@ from models.model import Site | |||
| def parse_app_site_args(): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('title', type=str, required=False, location='json') | |||
| parser.add_argument('icon_type', type=str, required=False, location='json') | |||
| parser.add_argument('icon', type=str, required=False, location='json') | |||
| parser.add_argument('icon_background', type=str, required=False, location='json') | |||
| parser.add_argument('description', type=str, required=False, location='json') | |||
| parser.add_argument('default_language', type=supported_language, required=False, location='json') | |||
| parser.add_argument('chat_color_theme', type=str, required=False, location='json') | |||
| parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json') | |||
| parser.add_argument('customize_domain', type=str, required=False, location='json') | |||
| parser.add_argument('copyright', type=str, required=False, location='json') | |||
| parser.add_argument('privacy_policy', type=str, required=False, location='json') | |||
| parser.add_argument('custom_disclaimer', type=str, required=False, location='json') | |||
| parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], | |||
| required=False, | |||
| location='json') | |||
| parser.add_argument('prompt_public', type=bool, required=False, location='json') | |||
| parser.add_argument('show_workflow_steps', type=bool, required=False, location='json') | |||
| parser.add_argument("title", type=str, required=False, location="json") | |||
| parser.add_argument("icon_type", type=str, required=False, location="json") | |||
| parser.add_argument("icon", type=str, required=False, location="json") | |||
| parser.add_argument("icon_background", type=str, required=False, location="json") | |||
| parser.add_argument("description", type=str, required=False, location="json") | |||
| parser.add_argument("default_language", type=supported_language, required=False, location="json") | |||
| parser.add_argument("chat_color_theme", type=str, required=False, location="json") | |||
| parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") | |||
| parser.add_argument("customize_domain", type=str, required=False, location="json") | |||
| parser.add_argument("copyright", type=str, required=False, location="json") | |||
| parser.add_argument("privacy_policy", type=str, required=False, location="json") | |||
| parser.add_argument("custom_disclaimer", type=str, required=False, location="json") | |||
| parser.add_argument( | |||
| "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" | |||
| ) | |||
| parser.add_argument("prompt_public", type=bool, required=False, location="json") | |||
| parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") | |||
| return parser.parse_args() | |||
| @@ -48,26 +48,24 @@ class AppSite(Resource): | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| site = db.session.query(Site). \ | |||
| filter(Site.app_id == app_model.id). \ | |||
| one_or_404() | |||
| site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() | |||
| for attr_name in [ | |||
| 'title', | |||
| 'icon_type', | |||
| 'icon', | |||
| 'icon_background', | |||
| 'description', | |||
| 'default_language', | |||
| 'chat_color_theme', | |||
| 'chat_color_theme_inverted', | |||
| 'customize_domain', | |||
| 'copyright', | |||
| 'privacy_policy', | |||
| 'custom_disclaimer', | |||
| 'customize_token_strategy', | |||
| 'prompt_public', | |||
| 'show_workflow_steps' | |||
| "title", | |||
| "icon_type", | |||
| "icon", | |||
| "icon_background", | |||
| "description", | |||
| "default_language", | |||
| "chat_color_theme", | |||
| "chat_color_theme_inverted", | |||
| "customize_domain", | |||
| "copyright", | |||
| "privacy_policy", | |||
| "custom_disclaimer", | |||
| "customize_token_strategy", | |||
| "prompt_public", | |||
| "show_workflow_steps", | |||
| ]: | |||
| value = args.get(attr_name) | |||
| if value is not None: | |||
| @@ -79,7 +77,6 @@ class AppSite(Resource): | |||
| class AppSiteAccessTokenReset(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -101,5 +98,5 @@ class AppSiteAccessTokenReset(Resource): | |||
| return site | |||
| api.add_resource(AppSite, '/apps/<uuid:app_id>/site') | |||
| api.add_resource(AppSiteAccessTokenReset, '/apps/<uuid:app_id>/site/access-token-reset') | |||
| api.add_resource(AppSite, "/apps/<uuid:app_id>/site") | |||
| api.add_resource(AppSiteAccessTokenReset, "/apps/<uuid:app_id>/site/access-token-reset") | |||
| @@ -17,7 +17,6 @@ from models.model import AppMode | |||
| class DailyConversationStatistic(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count | |||
| FROM messages where app_id = :app_id | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| """ | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'conversation_count': i.conversation_count | |||
| }) | |||
| response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| return jsonify({"data": response_data}) | |||
| class DailyTerminalsStatistic(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -86,54 +79,49 @@ class DailyTerminalsStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count | |||
| FROM messages where app_id = :app_id | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| """ | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'terminal_count': i.terminal_count | |||
| }) | |||
| response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| return jsonify({"data": response_data}) | |||
| class DailyTokenCostStatistic(Resource): | |||
| @@ -145,58 +133,53 @@ class DailyTokenCostStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, | |||
| sum(total_price) as total_price | |||
| FROM messages where app_id = :app_id | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| """ | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'token_count': i.token_count, | |||
| 'total_price': i.total_price, | |||
| 'currency': 'USD' | |||
| }) | |||
| response_data.append( | |||
| {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} | |||
| ) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| return jsonify({"data": response_data}) | |||
| class AverageSessionInteractionStatistic(Resource): | |||
| @@ -208,8 +191,8 @@ class AverageSessionInteractionStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| @@ -218,30 +201,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count | |||
| FROM conversations c | |||
| JOIN messages m ON c.id = m.conversation_id | |||
| WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and c.created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and c.created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and c.created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and c.created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += """ | |||
| GROUP BY m.conversation_id) subquery | |||
| @@ -250,18 +233,15 @@ GROUP BY date | |||
| ORDER BY date""" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'interactions': float(i.interactions.quantize(Decimal('0.01'))) | |||
| }) | |||
| response_data.append( | |||
| {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} | |||
| ) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| return jsonify({"data": response_data}) | |||
| class UserSatisfactionRateStatistic(Resource): | |||
| @@ -273,57 +253,57 @@ class UserSatisfactionRateStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count | |||
| FROM messages m | |||
| LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' | |||
| WHERE m.app_id = :app_id | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| """ | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and m.created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and m.created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and m.created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and m.created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), | |||
| }) | |||
| response_data.append( | |||
| { | |||
| "date": str(i.date), | |||
| "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), | |||
| } | |||
| ) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| return jsonify({"data": response_data}) | |||
| class AverageResponseTimeStatistic(Resource): | |||
| @@ -335,56 +315,51 @@ class AverageResponseTimeStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| AVG(provider_response_latency) as latency | |||
| FROM messages | |||
| WHERE app_id = :app_id | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| """ | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'latency': round(i.latency * 1000, 4) | |||
| }) | |||
| response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| return jsonify({"data": response_data}) | |||
| class TokensPerSecondStatistic(Resource): | |||
| @@ -396,63 +371,58 @@ class TokensPerSecondStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| CASE | |||
| WHEN SUM(provider_response_latency) = 0 THEN 0 | |||
| ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) | |||
| END as tokens_per_second | |||
| FROM messages | |||
| WHERE app_id = :app_id''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id} | |||
| WHERE app_id = :app_id""" | |||
| arg_dict = {"tz": account.timezone, "app_id": app_model.id} | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'tps': round(i.tokens_per_second, 4) | |||
| }) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| api.add_resource(DailyConversationStatistic, '/apps/<uuid:app_id>/statistics/daily-conversations') | |||
| api.add_resource(DailyTerminalsStatistic, '/apps/<uuid:app_id>/statistics/daily-end-users') | |||
| api.add_resource(DailyTokenCostStatistic, '/apps/<uuid:app_id>/statistics/token-costs') | |||
| api.add_resource(AverageSessionInteractionStatistic, '/apps/<uuid:app_id>/statistics/average-session-interactions') | |||
| api.add_resource(UserSatisfactionRateStatistic, '/apps/<uuid:app_id>/statistics/user-satisfaction-rate') | |||
| api.add_resource(AverageResponseTimeStatistic, '/apps/<uuid:app_id>/statistics/average-response-time') | |||
| api.add_resource(TokensPerSecondStatistic, '/apps/<uuid:app_id>/statistics/tokens-per-second') | |||
| response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) | |||
| return jsonify({"data": response_data}) | |||
| api.add_resource(DailyConversationStatistic, "/apps/<uuid:app_id>/statistics/daily-conversations") | |||
| api.add_resource(DailyTerminalsStatistic, "/apps/<uuid:app_id>/statistics/daily-end-users") | |||
| api.add_resource(DailyTokenCostStatistic, "/apps/<uuid:app_id>/statistics/token-costs") | |||
| api.add_resource(AverageSessionInteractionStatistic, "/apps/<uuid:app_id>/statistics/average-session-interactions") | |||
| api.add_resource(UserSatisfactionRateStatistic, "/apps/<uuid:app_id>/statistics/user-satisfaction-rate") | |||
| api.add_resource(AverageResponseTimeStatistic, "/apps/<uuid:app_id>/statistics/average-response-time") | |||
| api.add_resource(TokensPerSecondStatistic, "/apps/<uuid:app_id>/statistics/tokens-per-second") | |||
| @@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| content_type = request.headers.get('Content-Type', '') | |||
| if 'application/json' in content_type: | |||
| content_type = request.headers.get("Content-Type", "") | |||
| if "application/json" in content_type: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('features', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('hash', type=str, required=False, location='json') | |||
| parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("features", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("hash", type=str, required=False, location="json") | |||
| # TODO: set this to required=True after frontend is updated | |||
| parser.add_argument('environment_variables', type=list, required=False, location='json') | |||
| parser.add_argument('conversation_variables', type=list, required=False, location='json') | |||
| parser.add_argument("environment_variables", type=list, required=False, location="json") | |||
| parser.add_argument("conversation_variables", type=list, required=False, location="json") | |||
| args = parser.parse_args() | |||
| elif 'text/plain' in content_type: | |||
| elif "text/plain" in content_type: | |||
| try: | |||
| data = json.loads(request.data.decode('utf-8')) | |||
| if 'graph' not in data or 'features' not in data: | |||
| raise ValueError('graph or features not found in data') | |||
| data = json.loads(request.data.decode("utf-8")) | |||
| if "graph" not in data or "features" not in data: | |||
| raise ValueError("graph or features not found in data") | |||
| if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict): | |||
| raise ValueError('graph or features is not a dict') | |||
| if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): | |||
| raise ValueError("graph or features is not a dict") | |||
| args = { | |||
| 'graph': data.get('graph'), | |||
| 'features': data.get('features'), | |||
| 'hash': data.get('hash'), | |||
| 'environment_variables': data.get('environment_variables'), | |||
| 'conversation_variables': data.get('conversation_variables'), | |||
| "graph": data.get("graph"), | |||
| "features": data.get("features"), | |||
| "hash": data.get("hash"), | |||
| "environment_variables": data.get("environment_variables"), | |||
| "conversation_variables": data.get("conversation_variables"), | |||
| } | |||
| except json.JSONDecodeError: | |||
| return {'message': 'Invalid JSON data'}, 400 | |||
| return {"message": "Invalid JSON data"}, 400 | |||
| else: | |||
| abort(415) | |||
| workflow_service = WorkflowService() | |||
| try: | |||
| environment_variables_list = args.get('environment_variables') or [] | |||
| environment_variables_list = args.get("environment_variables") or [] | |||
| environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] | |||
| conversation_variables_list = args.get('conversation_variables') or [] | |||
| conversation_variables_list = args.get("conversation_variables") or [] | |||
| conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] | |||
| workflow = workflow_service.sync_draft_workflow( | |||
| app_model=app_model, | |||
| graph=args['graph'], | |||
| features=args['features'], | |||
| unique_hash=args.get('hash'), | |||
| graph=args["graph"], | |||
| features=args["features"], | |||
| unique_hash=args.get("hash"), | |||
| account=current_user, | |||
| environment_variables=environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| @@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource): | |||
| return { | |||
| "result": "success", | |||
| "hash": workflow.unique_hash, | |||
| "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) | |||
| "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), | |||
| } | |||
| @@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('data', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument("data", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| workflow = AppDslService.import_and_overwrite_workflow( | |||
| app_model=app_model, | |||
| data=args['data'], | |||
| account=current_user | |||
| app_model=app_model, data=args["data"], account=current_user | |||
| ) | |||
| return workflow | |||
| @@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, location='json') | |||
| parser.add_argument('query', type=str, required=True, location='json', default='') | |||
| parser.add_argument('files', type=list, location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json", default="") | |||
| parser.add_argument("files", type=list, location="json") | |||
| parser.add_argument("conversation_id", type=uuid_value, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=True | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, location='json') | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate_single_iteration( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| node_id=node_id, | |||
| args=args, | |||
| streaming=True | |||
| app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| class WorkflowDraftRunIterationNodeApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -239,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, location='json') | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate_single_iteration( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| node_id=node_id, | |||
| args=args, | |||
| streaming=True | |||
| app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| class DraftWorkflowRunApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -276,19 +265,15 @@ class DraftWorkflowRunApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=True | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -311,12 +296,10 @@ class WorkflowTaskStopApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) | |||
| return { | |||
| "result": "success" | |||
| } | |||
| return {"result": "success"} | |||
| class DraftWorkflowNodeRunApi(Resource): | |||
| @@ -332,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| workflow_service = WorkflowService() | |||
| workflow_node_execution = workflow_service.run_draft_workflow_node( | |||
| app_model=app_model, | |||
| node_id=node_id, | |||
| user_inputs=args.get('inputs'), | |||
| account=current_user | |||
| app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user | |||
| ) | |||
| return workflow_node_execution | |||
| class PublishedWorkflowApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -362,7 +341,7 @@ class PublishedWorkflowApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # fetch published workflow by app_model | |||
| workflow_service = WorkflowService() | |||
| workflow = workflow_service.get_published_workflow(app_model=app_model) | |||
| @@ -381,14 +360,11 @@ class PublishedWorkflowApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| workflow_service = WorkflowService() | |||
| workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) | |||
| return { | |||
| "result": "success", | |||
| "created_at": TimestampField().format(workflow.created_at) | |||
| } | |||
| return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} | |||
| class DefaultBlockConfigsApi(Resource): | |||
| @@ -403,7 +379,7 @@ class DefaultBlockConfigsApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # Get default block configs | |||
| workflow_service = WorkflowService() | |||
| return workflow_service.get_default_block_configs() | |||
| @@ -421,24 +397,21 @@ class DefaultBlockConfigApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('q', type=str, location='args') | |||
| parser.add_argument("q", type=str, location="args") | |||
| args = parser.parse_args() | |||
| filters = None | |||
| if args.get('q'): | |||
| if args.get("q"): | |||
| try: | |||
| filters = json.loads(args.get('q')) | |||
| filters = json.loads(args.get("q")) | |||
| except json.JSONDecodeError: | |||
| raise ValueError('Invalid filters') | |||
| raise ValueError("Invalid filters") | |||
| # Get default block configs | |||
| workflow_service = WorkflowService() | |||
| return workflow_service.get_default_block_config( | |||
| node_type=block_type, | |||
| filters=filters | |||
| ) | |||
| return workflow_service.get_default_block_config(node_type=block_type, filters=filters) | |||
| class ConvertToWorkflowApi(Resource): | |||
| @@ -455,41 +428,43 @@ class ConvertToWorkflowApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if request.data: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('icon', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument("name", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("icon", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| else: | |||
| args = {} | |||
| # convert to workflow mode | |||
| workflow_service = WorkflowService() | |||
| new_app_model = workflow_service.convert_to_workflow( | |||
| app_model=app_model, | |||
| account=current_user, | |||
| args=args | |||
| ) | |||
| new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) | |||
| # return app id | |||
| return { | |||
| 'new_app_id': new_app_model.id, | |||
| "new_app_id": new_app_model.id, | |||
| } | |||
| api.add_resource(DraftWorkflowApi, '/apps/<uuid:app_id>/workflows/draft') | |||
| api.add_resource(DraftWorkflowImportApi, '/apps/<uuid:app_id>/workflows/draft/import') | |||
| api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/run') | |||
| api.add_resource(DraftWorkflowRunApi, '/apps/<uuid:app_id>/workflows/draft/run') | |||
| api.add_resource(WorkflowTaskStopApi, '/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop') | |||
| api.add_resource(DraftWorkflowNodeRunApi, '/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run') | |||
| api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run') | |||
| api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run') | |||
| api.add_resource(PublishedWorkflowApi, '/apps/<uuid:app_id>/workflows/publish') | |||
| api.add_resource(DefaultBlockConfigsApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs') | |||
| api.add_resource(DefaultBlockConfigApi, '/apps/<uuid:app_id>/workflows/default-workflow-block-configs' | |||
| '/<string:block_type>') | |||
| api.add_resource(ConvertToWorkflowApi, '/apps/<uuid:app_id>/convert-to-workflow') | |||
| api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") | |||
| api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import") | |||
| api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") | |||
| api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") | |||
| api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") | |||
| api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run") | |||
| api.add_resource( | |||
| AdvancedChatDraftRunIterationNodeApi, | |||
| "/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run", | |||
| ) | |||
| api.add_resource( | |||
| WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run" | |||
| ) | |||
| api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish") | |||
| api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs") | |||
| api.add_resource( | |||
| DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs" "/<string:block_type>" | |||
| ) | |||
| api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow") | |||
| @@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource): | |||
| Get workflow app logs | |||
| """ | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('keyword', type=str, location='args') | |||
| parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') | |||
| parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') | |||
| parser.add_argument("keyword", type=str, location="args") | |||
| parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") | |||
| parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| args = parser.parse_args() | |||
| # get paginate workflow app logs | |||
| workflow_app_service = WorkflowAppService() | |||
| workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( | |||
| app_model=app_model, | |||
| args=args | |||
| app_model=app_model, args=args | |||
| ) | |||
| return workflow_app_log_pagination | |||
| api.add_resource(WorkflowAppLogApi, '/apps/<uuid:app_id>/workflow-app-logs') | |||
| api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs") | |||
| @@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource): | |||
| Get advanced chat app workflow run list | |||
| """ | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| workflow_run_service = WorkflowRunService() | |||
| result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( | |||
| app_model=app_model, | |||
| args=args | |||
| ) | |||
| result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) | |||
| return result | |||
| @@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource): | |||
| Get workflow run list | |||
| """ | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| workflow_run_service = WorkflowRunService() | |||
| result = workflow_run_service.get_paginate_workflow_runs( | |||
| app_model=app_model, | |||
| args=args | |||
| ) | |||
| result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) | |||
| return result | |||
| @@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource): | |||
| workflow_run_service = WorkflowRunService() | |||
| node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) | |||
| return { | |||
| 'data': node_executions | |||
| } | |||
| return {"data": node_executions} | |||
| api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps/<uuid:app_id>/advanced-chat/workflow-runs') | |||
| api.add_resource(WorkflowRunListApi, '/apps/<uuid:app_id>/workflow-runs') | |||
| api.add_resource(WorkflowRunDetailApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>') | |||
| api.add_resource(WorkflowRunNodeExecutionListApi, '/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions') | |||
| api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps/<uuid:app_id>/advanced-chat/workflow-runs") | |||
| api.add_resource(WorkflowRunListApi, "/apps/<uuid:app_id>/workflow-runs") | |||
| api.add_resource(WorkflowRunDetailApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>") | |||
| api.add_resource(WorkflowRunNodeExecutionListApi, "/apps/<uuid:app_id>/workflow-runs/<uuid:run_id>/node-executions") | |||
| @@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs | |||
| FROM workflow_runs | |||
| WHERE app_id = :app_id | |||
| AND triggered_from = :triggered_from | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} | |||
| """ | |||
| arg_dict = { | |||
| "tz": account.timezone, | |||
| "app_id": app_model.id, | |||
| "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, | |||
| } | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'runs': i.runs | |||
| }) | |||
| response_data.append({"date": str(i.date), "runs": i.runs}) | |||
| return jsonify({"data": response_data}) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| class WorkflowDailyTerminalsStatistic(Resource): | |||
| @setup_required | |||
| @@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count | |||
| FROM workflow_runs | |||
| WHERE app_id = :app_id | |||
| AND triggered_from = :triggered_from | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} | |||
| """ | |||
| arg_dict = { | |||
| "tz": account.timezone, | |||
| "app_id": app_model.id, | |||
| "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, | |||
| } | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'terminal_count': i.terminal_count | |||
| }) | |||
| response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) | |||
| return jsonify({"data": response_data}) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| class WorkflowDailyTokenCostStatistic(Resource): | |||
| @setup_required | |||
| @@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = ''' | |||
| sql_query = """ | |||
| SELECT | |||
| date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, | |||
| SUM(workflow_runs.total_tokens) as token_count | |||
| FROM workflow_runs | |||
| WHERE app_id = :app_id | |||
| AND triggered_from = :triggered_from | |||
| ''' | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} | |||
| """ | |||
| arg_dict = { | |||
| "tz": account.timezone, | |||
| "app_id": app_model.id, | |||
| "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, | |||
| } | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at >= :start' | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query += " and created_at >= :start" | |||
| arg_dict["start"] = start_datetime_utc | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query += ' and created_at < :end' | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query += " and created_at < :end" | |||
| arg_dict["end"] = end_datetime_utc | |||
| sql_query += ' GROUP BY date order by date' | |||
| sql_query += " GROUP BY date order by date" | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'token_count': i.token_count, | |||
| }) | |||
| response_data.append( | |||
| { | |||
| "date": str(i.date), | |||
| "token_count": i.token_count, | |||
| } | |||
| ) | |||
| return jsonify({"data": response_data}) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| class WorkflowAverageAppInteractionStatistic(Resource): | |||
| @setup_required | |||
| @@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): | |||
| account = current_user | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') | |||
| parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") | |||
| args = parser.parse_args() | |||
| sql_query = """ | |||
| @@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource): | |||
| GROUP BY date, c.created_by) sub | |||
| GROUP BY sub.date | |||
| """ | |||
| arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} | |||
| arg_dict = { | |||
| "tz": account.timezone, | |||
| "app_id": app_model.id, | |||
| "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, | |||
| } | |||
| timezone = pytz.timezone(account.timezone) | |||
| utc_timezone = pytz.utc | |||
| if args['start']: | |||
| start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') | |||
| if args["start"]: | |||
| start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") | |||
| start_datetime = start_datetime.replace(second=0) | |||
| start_datetime_timezone = timezone.localize(start_datetime) | |||
| start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start') | |||
| arg_dict['start'] = start_datetime_utc | |||
| sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") | |||
| arg_dict["start"] = start_datetime_utc | |||
| else: | |||
| sql_query = sql_query.replace('{{start}}', '') | |||
| sql_query = sql_query.replace("{{start}}", "") | |||
| if args['end']: | |||
| end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') | |||
| if args["end"]: | |||
| end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") | |||
| end_datetime = end_datetime.replace(second=0) | |||
| end_datetime_timezone = timezone.localize(end_datetime) | |||
| end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) | |||
| sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end') | |||
| arg_dict['end'] = end_datetime_utc | |||
| sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") | |||
| arg_dict["end"] = end_datetime_utc | |||
| else: | |||
| sql_query = sql_query.replace('{{end}}', '') | |||
| sql_query = sql_query.replace("{{end}}", "") | |||
| response_data = [] | |||
| with db.engine.begin() as conn: | |||
| rs = conn.execute(db.text(sql_query), arg_dict) | |||
| for i in rs: | |||
| response_data.append({ | |||
| 'date': str(i.date), | |||
| 'interactions': float(i.interactions.quantize(Decimal('0.01'))) | |||
| }) | |||
| return jsonify({ | |||
| 'data': response_data | |||
| }) | |||
| api.add_resource(WorkflowDailyRunsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-conversations') | |||
| api.add_resource(WorkflowDailyTerminalsStatistic, '/apps/<uuid:app_id>/workflow/statistics/daily-terminals') | |||
| api.add_resource(WorkflowDailyTokenCostStatistic, '/apps/<uuid:app_id>/workflow/statistics/token-costs') | |||
| api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps/<uuid:app_id>/workflow/statistics/average-app-interactions') | |||
| response_data.append( | |||
| {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} | |||
| ) | |||
| return jsonify({"data": response_data}) | |||
| api.add_resource(WorkflowDailyRunsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-conversations") | |||
| api.add_resource(WorkflowDailyTerminalsStatistic, "/apps/<uuid:app_id>/workflow/statistics/daily-terminals") | |||
| api.add_resource(WorkflowDailyTokenCostStatistic, "/apps/<uuid:app_id>/workflow/statistics/token-costs") | |||
| api.add_resource( | |||
| WorkflowAverageAppInteractionStatistic, "/apps/<uuid:app_id>/workflow/statistics/average-app-interactions" | |||
| ) | |||
| @@ -8,24 +8,23 @@ from libs.login import current_user | |||
| from models.model import App, AppMode | |||
| def get_app_model(view: Optional[Callable] = None, *, | |||
| mode: Union[AppMode, list[AppMode]] = None): | |||
| def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): | |||
| def decorator(view_func): | |||
| @wraps(view_func) | |||
| def decorated_view(*args, **kwargs): | |||
| if not kwargs.get('app_id'): | |||
| raise ValueError('missing app_id in path parameters') | |||
| if not kwargs.get("app_id"): | |||
| raise ValueError("missing app_id in path parameters") | |||
| app_id = kwargs.get('app_id') | |||
| app_id = kwargs.get("app_id") | |||
| app_id = str(app_id) | |||
| del kwargs['app_id'] | |||
| del kwargs["app_id"] | |||
| app_model = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app_model = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app_model: | |||
| raise AppNotFoundError() | |||
| @@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *, | |||
| mode_values = {m.value for m in modes} | |||
| raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") | |||
| kwargs['app_model'] = app_model | |||
| kwargs["app_model"] = app_model | |||
| return view_func(*args, **kwargs) | |||
| return decorated_view | |||
| if view is None: | |||
| @@ -17,60 +17,61 @@ from services.account_service import RegisterService | |||
| class ActivateCheckApi(Resource): | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') | |||
| parser.add_argument('email', type=email, required=False, nullable=True, location='args') | |||
| parser.add_argument('token', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") | |||
| parser.add_argument("email", type=email, required=False, nullable=True, location="args") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="args") | |||
| args = parser.parse_args() | |||
| workspaceId = args['workspace_id'] | |||
| reg_email = args['email'] | |||
| token = args['token'] | |||
| workspaceId = args["workspace_id"] | |||
| reg_email = args["email"] | |||
| token = args["token"] | |||
| invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) | |||
| return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} | |||
| return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} | |||
| class ActivateApi(Resource): | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('email', type=email, required=False, nullable=True, location='json') | |||
| parser.add_argument('token', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') | |||
| parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') | |||
| parser.add_argument('interface_language', type=supported_language, required=True, nullable=False, | |||
| location='json') | |||
| parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') | |||
| parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("email", type=email, required=False, nullable=True, location="json") | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "interface_language", type=supported_language, required=True, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) | |||
| invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) | |||
| if invitation is None: | |||
| raise AlreadyActivateError() | |||
| RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) | |||
| RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) | |||
| account = invitation['account'] | |||
| account.name = args['name'] | |||
| account = invitation["account"] | |||
| account.name = args["name"] | |||
| # generate password salt | |||
| salt = secrets.token_bytes(16) | |||
| base64_salt = base64.b64encode(salt).decode() | |||
| # encrypt password with salt | |||
| password_hashed = hash_password(args['password'], salt) | |||
| password_hashed = hash_password(args["password"], salt) | |||
| base64_password_hashed = base64.b64encode(password_hashed).decode() | |||
| account.password = base64_password_hashed | |||
| account.password_salt = base64_salt | |||
| account.interface_language = args['interface_language'] | |||
| account.timezone = args['timezone'] | |||
| account.interface_theme = 'light' | |||
| account.interface_language = args["interface_language"] | |||
| account.timezone = args["timezone"] | |||
| account.interface_theme = "light" | |||
| account.status = AccountStatus.ACTIVE.value | |||
| account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| api.add_resource(ActivateCheckApi, '/activate/check') | |||
| api.add_resource(ActivateApi, '/activate') | |||
| api.add_resource(ActivateCheckApi, "/activate/check") | |||
| api.add_resource(ActivateApi, "/activate") | |||
| @@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource): | |||
| data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) | |||
| if data_source_api_key_bindings: | |||
| return { | |||
| 'sources': [{ | |||
| 'id': data_source_api_key_binding.id, | |||
| 'category': data_source_api_key_binding.category, | |||
| 'provider': data_source_api_key_binding.provider, | |||
| 'disabled': data_source_api_key_binding.disabled, | |||
| 'created_at': int(data_source_api_key_binding.created_at.timestamp()), | |||
| 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()), | |||
| } | |||
| for data_source_api_key_binding in | |||
| data_source_api_key_bindings] | |||
| "sources": [ | |||
| { | |||
| "id": data_source_api_key_binding.id, | |||
| "category": data_source_api_key_binding.category, | |||
| "provider": data_source_api_key_binding.provider, | |||
| "disabled": data_source_api_key_binding.disabled, | |||
| "created_at": int(data_source_api_key_binding.created_at.timestamp()), | |||
| "updated_at": int(data_source_api_key_binding.updated_at.timestamp()), | |||
| } | |||
| for data_source_api_key_binding in data_source_api_key_bindings | |||
| ] | |||
| } | |||
| return {'sources': []} | |||
| return {"sources": []} | |||
| class ApiKeyAuthDataSourceBinding(Resource): | |||
| @@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('category', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("category", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("provider", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| ApiKeyAuthService.validate_api_key_auth_args(args) | |||
| try: | |||
| ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) | |||
| except Exception as e: | |||
| raise ApiKeyAuthFailedError(str(e)) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class ApiKeyAuthDataSourceBindingDelete(Resource): | |||
| @@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): | |||
| ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') | |||
| api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') | |||
| api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>') | |||
| api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") | |||
| api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") | |||
| api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/<uuid:binding_id>") | |||
| @@ -17,13 +17,13 @@ from ..wraps import account_initialization_required | |||
| def get_oauth_providers(): | |||
| with current_app.app_context(): | |||
| notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID, | |||
| client_secret=dify_config.NOTION_CLIENT_SECRET, | |||
| redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion') | |||
| notion_oauth = NotionOAuth( | |||
| client_id=dify_config.NOTION_CLIENT_ID, | |||
| client_secret=dify_config.NOTION_CLIENT_SECRET, | |||
| redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", | |||
| ) | |||
| OAUTH_PROVIDERS = { | |||
| 'notion': notion_oauth | |||
| } | |||
| OAUTH_PROVIDERS = {"notion": notion_oauth} | |||
| return OAUTH_PROVIDERS | |||
| @@ -37,18 +37,16 @@ class OAuthDataSource(Resource): | |||
| oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | |||
| print(vars(oauth_provider)) | |||
| if not oauth_provider: | |||
| return {'error': 'Invalid provider'}, 400 | |||
| if dify_config.NOTION_INTEGRATION_TYPE == 'internal': | |||
| return {"error": "Invalid provider"}, 400 | |||
| if dify_config.NOTION_INTEGRATION_TYPE == "internal": | |||
| internal_secret = dify_config.NOTION_INTERNAL_SECRET | |||
| if not internal_secret: | |||
| return {'error': 'Internal secret is not set'}, | |||
| return ({"error": "Internal secret is not set"},) | |||
| oauth_provider.save_internal_access_token(internal_secret) | |||
| return { 'data': '' } | |||
| return {"data": ""} | |||
| else: | |||
| auth_url = oauth_provider.get_authorization_url() | |||
| return { 'data': auth_url }, 200 | |||
| return {"data": auth_url}, 200 | |||
| class OAuthDataSourceCallback(Resource): | |||
| @@ -57,18 +55,18 @@ class OAuthDataSourceCallback(Resource): | |||
| with current_app.app_context(): | |||
| oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | |||
| if not oauth_provider: | |||
| return {'error': 'Invalid provider'}, 400 | |||
| if 'code' in request.args: | |||
| code = request.args.get('code') | |||
| return {"error": "Invalid provider"}, 400 | |||
| if "code" in request.args: | |||
| code = request.args.get("code") | |||
| return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}') | |||
| elif 'error' in request.args: | |||
| error = request.args.get('error') | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") | |||
| elif "error" in request.args: | |||
| error = request.args.get("error") | |||
| return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}') | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") | |||
| else: | |||
| return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied') | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") | |||
| class OAuthDataSourceBinding(Resource): | |||
| def get(self, provider: str): | |||
| @@ -76,17 +74,18 @@ class OAuthDataSourceBinding(Resource): | |||
| with current_app.app_context(): | |||
| oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | |||
| if not oauth_provider: | |||
| return {'error': 'Invalid provider'}, 400 | |||
| if 'code' in request.args: | |||
| code = request.args.get('code') | |||
| return {"error": "Invalid provider"}, 400 | |||
| if "code" in request.args: | |||
| code = request.args.get("code") | |||
| try: | |||
| oauth_provider.get_access_token(code) | |||
| except requests.exceptions.HTTPError as e: | |||
| logging.exception( | |||
| f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") | |||
| return {'error': 'OAuth data source process failed'}, 400 | |||
| f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" | |||
| ) | |||
| return {"error": "OAuth data source process failed"}, 400 | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class OAuthDataSourceSync(Resource): | |||
| @@ -100,18 +99,17 @@ class OAuthDataSourceSync(Resource): | |||
| with current_app.app_context(): | |||
| oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) | |||
| if not oauth_provider: | |||
| return {'error': 'Invalid provider'}, 400 | |||
| return {"error": "Invalid provider"}, 400 | |||
| try: | |||
| oauth_provider.sync_data_source(binding_id) | |||
| except requests.exceptions.HTTPError as e: | |||
| logging.exception( | |||
| f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") | |||
| return {'error': 'OAuth data source process failed'}, 400 | |||
| logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") | |||
| return {"error": "OAuth data source process failed"}, 400 | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>') | |||
| api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>') | |||
| api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>') | |||
| api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync') | |||
| api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>") | |||
| api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>") | |||
| api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>") | |||
| api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") | |||
| @@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException | |||
| class ApiKeyAuthFailedError(BaseHTTPException): | |||
| error_code = 'auth_failed' | |||
| error_code = "auth_failed" | |||
| description = "{message}" | |||
| code = 500 | |||
| class InvalidEmailError(BaseHTTPException): | |||
| error_code = 'invalid_email' | |||
| error_code = "invalid_email" | |||
| description = "The email address is not valid." | |||
| code = 400 | |||
| class PasswordMismatchError(BaseHTTPException): | |||
| error_code = 'password_mismatch' | |||
| error_code = "password_mismatch" | |||
| description = "The passwords do not match." | |||
| code = 400 | |||
| class InvalidTokenError(BaseHTTPException): | |||
| error_code = 'invalid_or_expired_token' | |||
| error_code = "invalid_or_expired_token" | |||
| description = "The token is invalid or has expired." | |||
| code = 400 | |||
| class PasswordResetRateLimitExceededError(BaseHTTPException): | |||
| error_code = 'password_reset_rate_limit_exceeded' | |||
| error_code = "password_reset_rate_limit_exceeded" | |||
| description = "Password reset rate limit exceeded. Try again later." | |||
| code = 429 | |||
| @@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError | |||
| class ForgotPasswordSendEmailApi(Resource): | |||
| @setup_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('email', type=str, required=True, location='json') | |||
| parser.add_argument("email", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| email = args['email'] | |||
| email = args["email"] | |||
| if not email_validate(email): | |||
| raise InvalidEmailError() | |||
| @@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource): | |||
| class ForgotPasswordCheckApi(Resource): | |||
| @setup_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| token = args['token'] | |||
| token = args["token"] | |||
| reset_data = AccountService.get_reset_password_data(token) | |||
| if reset_data is None: | |||
| return {'is_valid': False, 'email': None} | |||
| return {'is_valid': True, 'email': reset_data.get('email')} | |||
| return {"is_valid": False, "email": None} | |||
| return {"is_valid": True, "email": reset_data.get("email")} | |||
| class ForgotPasswordResetApi(Resource): | |||
| @setup_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json') | |||
| parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json') | |||
| parser.add_argument("token", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") | |||
| parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| new_password = args['new_password'] | |||
| password_confirm = args['password_confirm'] | |||
| new_password = args["new_password"] | |||
| password_confirm = args["password_confirm"] | |||
| if str(new_password).strip() != str(password_confirm).strip(): | |||
| raise PasswordMismatchError() | |||
| token = args['token'] | |||
| token = args["token"] | |||
| reset_data = AccountService.get_reset_password_data(token) | |||
| if reset_data is None: | |||
| @@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource): | |||
| password_hashed = hash_password(new_password, salt) | |||
| base64_password_hashed = base64.b64encode(password_hashed).decode() | |||
| account = Account.query.filter_by(email=reset_data.get('email')).first() | |||
| account = Account.query.filter_by(email=reset_data.get("email")).first() | |||
| account.password = base64_password_hashed | |||
| account.password_salt = base64_salt | |||
| db.session.commit() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password') | |||
| api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity') | |||
| api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets') | |||
| api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") | |||
| api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") | |||
| api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") | |||
| @@ -20,37 +20,39 @@ class LoginApi(Resource): | |||
| def post(self): | |||
| """Authenticate user and login.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('email', type=email, required=True, location='json') | |||
| parser.add_argument('password', type=valid_password, required=True, location='json') | |||
| parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| parser.add_argument("password", type=valid_password, required=True, location="json") | |||
| parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") | |||
| args = parser.parse_args() | |||
| # todo: Verify the recaptcha | |||
| try: | |||
| account = AccountService.authenticate(args['email'], args['password']) | |||
| account = AccountService.authenticate(args["email"], args["password"]) | |||
| except services.errors.account.AccountLoginError as e: | |||
| return {'code': 'unauthorized', 'message': str(e)}, 401 | |||
| return {"code": "unauthorized", "message": str(e)}, 401 | |||
| # SELF_HOSTED only have one workspace | |||
| tenants = TenantService.get_join_tenants(account) | |||
| if len(tenants) == 0: | |||
| return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} | |||
| return { | |||
| "result": "fail", | |||
| "data": "workspace not found, please contact system admin to invite you to join in a workspace", | |||
| } | |||
| token = AccountService.login(account, ip_address=get_remote_ip(request)) | |||
| return {'result': 'success', 'data': token} | |||
| return {"result": "success", "data": token} | |||
| class LogoutApi(Resource): | |||
| @setup_required | |||
| def get(self): | |||
| account = cast(Account, flask_login.current_user) | |||
| token = request.headers.get('Authorization', '').split(' ')[1] | |||
| token = request.headers.get("Authorization", "").split(" ")[1] | |||
| AccountService.logout(account=account, token=token) | |||
| flask_login.logout_user() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class ResetPasswordApi(Resource): | |||
| @@ -80,11 +82,11 @@ class ResetPasswordApi(Resource): | |||
| # 'subject': 'Reset your Dify password', | |||
| # 'html': """ | |||
| # <p>Dear User,</p> | |||
| # <p>The Dify team has generated a new password for you, details as follows:</p> | |||
| # <p>The Dify team has generated a new password for you, details as follows:</p> | |||
| # <p><strong>{new_password}</strong></p> | |||
| # <p>Please change your password to log in as soon as possible.</p> | |||
| # <p>Regards,</p> | |||
| # <p>The Dify Team</p> | |||
| # <p>The Dify Team</p> | |||
| # """ | |||
| # } | |||
| @@ -101,8 +103,8 @@ class ResetPasswordApi(Resource): | |||
| # # handle error | |||
| # pass | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| api.add_resource(LoginApi, '/login') | |||
| api.add_resource(LogoutApi, '/logout') | |||
| api.add_resource(LoginApi, "/login") | |||
| api.add_resource(LogoutApi, "/logout") | |||
| @@ -25,7 +25,7 @@ def get_oauth_providers(): | |||
| github_oauth = GitHubOAuth( | |||
| client_id=dify_config.GITHUB_CLIENT_ID, | |||
| client_secret=dify_config.GITHUB_CLIENT_SECRET, | |||
| redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', | |||
| redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", | |||
| ) | |||
| if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: | |||
| google_oauth = None | |||
| @@ -33,10 +33,10 @@ def get_oauth_providers(): | |||
| google_oauth = GoogleOAuth( | |||
| client_id=dify_config.GOOGLE_CLIENT_ID, | |||
| client_secret=dify_config.GOOGLE_CLIENT_SECRET, | |||
| redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', | |||
| redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", | |||
| ) | |||
| OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} | |||
| OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} | |||
| return OAUTH_PROVIDERS | |||
| @@ -47,7 +47,7 @@ class OAuthLogin(Resource): | |||
| oauth_provider = OAUTH_PROVIDERS.get(provider) | |||
| print(vars(oauth_provider)) | |||
| if not oauth_provider: | |||
| return {'error': 'Invalid provider'}, 400 | |||
| return {"error": "Invalid provider"}, 400 | |||
| auth_url = oauth_provider.get_authorization_url() | |||
| return redirect(auth_url) | |||
| @@ -59,20 +59,20 @@ class OAuthCallback(Resource): | |||
| with current_app.app_context(): | |||
| oauth_provider = OAUTH_PROVIDERS.get(provider) | |||
| if not oauth_provider: | |||
| return {'error': 'Invalid provider'}, 400 | |||
| return {"error": "Invalid provider"}, 400 | |||
| code = request.args.get('code') | |||
| code = request.args.get("code") | |||
| try: | |||
| token = oauth_provider.get_access_token(code) | |||
| user_info = oauth_provider.get_user_info(token) | |||
| except requests.exceptions.HTTPError as e: | |||
| logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') | |||
| return {'error': 'OAuth process failed'}, 400 | |||
| logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") | |||
| return {"error": "OAuth process failed"}, 400 | |||
| account = _generate_account(provider, user_info) | |||
| # Check account status | |||
| if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | |||
| return {'error': 'Account is banned or closed.'}, 403 | |||
| return {"error": "Account is banned or closed."}, 403 | |||
| if account.status == AccountStatus.PENDING.value: | |||
| account.status = AccountStatus.ACTIVE.value | |||
| @@ -83,7 +83,7 @@ class OAuthCallback(Resource): | |||
| token = AccountService.login(account, ip_address=get_remote_ip(request)) | |||
| return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') | |||
| return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") | |||
| def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: | |||
| @@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): | |||
| if not account: | |||
| # Create account | |||
| account_name = user_info.name if user_info.name else 'Dify' | |||
| account_name = user_info.name if user_info.name else "Dify" | |||
| account = RegisterService.register( | |||
| email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider | |||
| ) | |||
| @@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): | |||
| return account | |||
| api.add_resource(OAuthLogin, '/oauth/login/<provider>') | |||
| api.add_resource(OAuthCallback, '/oauth/authorize/<provider>') | |||
| api.add_resource(OAuthLogin, "/oauth/login/<provider>") | |||
| api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") | |||
| @@ -9,28 +9,24 @@ from services.billing_service import BillingService | |||
| class Subscription(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @only_edition_cloud | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) | |||
| parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) | |||
| parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) | |||
| parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) | |||
| args = parser.parse_args() | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| return BillingService.get_subscription(args['plan'], | |||
| args['interval'], | |||
| current_user.email, | |||
| current_user.current_tenant_id) | |||
| return BillingService.get_subscription( | |||
| args["plan"], args["interval"], current_user.email, current_user.current_tenant_id | |||
| ) | |||
| class Invoices(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -40,5 +36,5 @@ class Invoices(Resource): | |||
| return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) | |||
| api.add_resource(Subscription, '/billing/subscription') | |||
| api.add_resource(Invoices, '/billing/invoices') | |||
| api.add_resource(Subscription, "/billing/subscription") | |||
| api.add_resource(Invoices, "/billing/invoices") | |||
| @@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task | |||
| class DataSourceApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(integrate_list_fields) | |||
| def get(self): | |||
| # get workspace data source integrates | |||
| data_source_integrates = db.session.query(DataSourceOauthBinding).filter( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.disabled == False | |||
| ).all() | |||
| data_source_integrates = ( | |||
| db.session.query(DataSourceOauthBinding) | |||
| .filter( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.disabled == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| base_url = request.url_root.rstrip('/') | |||
| base_url = request.url_root.rstrip("/") | |||
| data_source_oauth_base_path = "/console/api/oauth/data-source" | |||
| providers = ["notion"] | |||
| @@ -44,26 +47,30 @@ class DataSourceApi(Resource): | |||
| existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) | |||
| if existing_integrates: | |||
| for existing_integrate in list(existing_integrates): | |||
| integrate_data.append({ | |||
| 'id': existing_integrate.id, | |||
| 'provider': provider, | |||
| 'created_at': existing_integrate.created_at, | |||
| 'is_bound': True, | |||
| 'disabled': existing_integrate.disabled, | |||
| 'source_info': existing_integrate.source_info, | |||
| 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' | |||
| }) | |||
| integrate_data.append( | |||
| { | |||
| "id": existing_integrate.id, | |||
| "provider": provider, | |||
| "created_at": existing_integrate.created_at, | |||
| "is_bound": True, | |||
| "disabled": existing_integrate.disabled, | |||
| "source_info": existing_integrate.source_info, | |||
| "link": f"{base_url}{data_source_oauth_base_path}/{provider}", | |||
| } | |||
| ) | |||
| else: | |||
| integrate_data.append({ | |||
| 'id': None, | |||
| 'provider': provider, | |||
| 'created_at': None, | |||
| 'source_info': None, | |||
| 'is_bound': False, | |||
| 'disabled': None, | |||
| 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' | |||
| }) | |||
| return {'data': integrate_data}, 200 | |||
| integrate_data.append( | |||
| { | |||
| "id": None, | |||
| "provider": provider, | |||
| "created_at": None, | |||
| "source_info": None, | |||
| "is_bound": False, | |||
| "disabled": None, | |||
| "link": f"{base_url}{data_source_oauth_base_path}/{provider}", | |||
| } | |||
| ) | |||
| return {"data": integrate_data}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @@ -71,92 +78,82 @@ class DataSourceApi(Resource): | |||
| def patch(self, binding_id, action): | |||
| binding_id = str(binding_id) | |||
| action = str(action) | |||
| data_source_binding = DataSourceOauthBinding.query.filter_by( | |||
| id=binding_id | |||
| ).first() | |||
| data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() | |||
| if data_source_binding is None: | |||
| raise NotFound('Data source binding not found.') | |||
| raise NotFound("Data source binding not found.") | |||
| # enable binding | |||
| if action == 'enable': | |||
| if action == "enable": | |||
| if data_source_binding.disabled: | |||
| data_source_binding.disabled = False | |||
| data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(data_source_binding) | |||
| db.session.commit() | |||
| else: | |||
| raise ValueError('Data source is not disabled.') | |||
| raise ValueError("Data source is not disabled.") | |||
| # disable binding | |||
| if action == 'disable': | |||
| if action == "disable": | |||
| if not data_source_binding.disabled: | |||
| data_source_binding.disabled = True | |||
| data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(data_source_binding) | |||
| db.session.commit() | |||
| else: | |||
| raise ValueError('Data source is disabled.') | |||
| return {'result': 'success'}, 200 | |||
| raise ValueError("Data source is disabled.") | |||
| return {"result": "success"}, 200 | |||
| class DataSourceNotionListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(integrate_notion_info_list_fields) | |||
| def get(self): | |||
| dataset_id = request.args.get('dataset_id', default=None, type=str) | |||
| dataset_id = request.args.get("dataset_id", default=None, type=str) | |||
| exist_page_ids = [] | |||
| # import notion in the exist dataset | |||
| if dataset_id: | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| if dataset.data_source_type != 'notion_import': | |||
| raise ValueError('Dataset is not notion type.') | |||
| raise NotFound("Dataset not found.") | |||
| if dataset.data_source_type != "notion_import": | |||
| raise ValueError("Dataset is not notion type.") | |||
| documents = Document.query.filter_by( | |||
| dataset_id=dataset_id, | |||
| tenant_id=current_user.current_tenant_id, | |||
| data_source_type='notion_import', | |||
| enabled=True | |||
| data_source_type="notion_import", | |||
| enabled=True, | |||
| ).all() | |||
| if documents: | |||
| for document in documents: | |||
| data_source_info = json.loads(document.data_source_info) | |||
| exist_page_ids.append(data_source_info['notion_page_id']) | |||
| exist_page_ids.append(data_source_info["notion_page_id"]) | |||
| # get all authorized pages | |||
| data_source_bindings = DataSourceOauthBinding.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider='notion', | |||
| disabled=False | |||
| tenant_id=current_user.current_tenant_id, provider="notion", disabled=False | |||
| ).all() | |||
| if not data_source_bindings: | |||
| return { | |||
| 'notion_info': [] | |||
| }, 200 | |||
| return {"notion_info": []}, 200 | |||
| pre_import_info_list = [] | |||
| for data_source_binding in data_source_bindings: | |||
| source_info = data_source_binding.source_info | |||
| pages = source_info['pages'] | |||
| pages = source_info["pages"] | |||
| # Filter out already bound pages | |||
| for page in pages: | |||
| if page['page_id'] in exist_page_ids: | |||
| page['is_bound'] = True | |||
| if page["page_id"] in exist_page_ids: | |||
| page["is_bound"] = True | |||
| else: | |||
| page['is_bound'] = False | |||
| page["is_bound"] = False | |||
| pre_import_info = { | |||
| 'workspace_name': source_info['workspace_name'], | |||
| 'workspace_icon': source_info['workspace_icon'], | |||
| 'workspace_id': source_info['workspace_id'], | |||
| 'pages': pages, | |||
| "workspace_name": source_info["workspace_name"], | |||
| "workspace_icon": source_info["workspace_icon"], | |||
| "workspace_id": source_info["workspace_id"], | |||
| "pages": pages, | |||
| } | |||
| pre_import_info_list.append(pre_import_info) | |||
| return { | |||
| 'notion_info': pre_import_info_list | |||
| }, 200 | |||
| return {"notion_info": pre_import_info_list}, 200 | |||
| class DataSourceNotionApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource): | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.provider == "notion", | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| raise NotFound('Data source binding not found.') | |||
| raise NotFound("Data source binding not found.") | |||
| extractor = NotionExtractor( | |||
| notion_workspace_id=workspace_id, | |||
| notion_obj_id=page_id, | |||
| notion_page_type=page_type, | |||
| notion_access_token=data_source_binding.access_token, | |||
| tenant_id=current_user.current_tenant_id | |||
| tenant_id=current_user.current_tenant_id, | |||
| ) | |||
| text_docs = extractor.extract() | |||
| return { | |||
| 'content': "\n".join([doc.page_content for doc in text_docs]) | |||
| }, 200 | |||
| return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') | |||
| parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") | |||
| parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| args = parser.parse_args() | |||
| # validate args | |||
| DocumentService.estimate_args_validate(args) | |||
| notion_info_list = args['notion_info_list'] | |||
| notion_info_list = args["notion_info_list"] | |||
| extract_settings = [] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| for page in notion_info['pages']: | |||
| workspace_id = notion_info["workspace_id"] | |||
| for page in notion_info["pages"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page['page_id'], | |||
| "notion_page_type": page['type'], | |||
| "tenant_id": current_user.current_tenant_id | |||
| "notion_obj_id": page["page_id"], | |||
| "notion_page_type": page["type"], | |||
| "tenant_id": current_user.current_tenant_id, | |||
| }, | |||
| document_model=args['doc_form'] | |||
| document_model=args["doc_form"], | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, | |||
| args['process_rule'], args['doc_form'], | |||
| args['doc_language']) | |||
| response = indexing_runner.indexing_estimate( | |||
| current_user.current_tenant_id, | |||
| extract_settings, | |||
| args["process_rule"], | |||
| args["doc_form"], | |||
| args["doc_language"], | |||
| ) | |||
| return response, 200 | |||
| class DataSourceNotionDatasetSyncApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource): | |||
| class DataSourceNotionDocumentSyncApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource): | |||
| return 200 | |||
| api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates/<uuid:binding_id>/<string:action>') | |||
| api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') | |||
| api.add_resource(DataSourceNotionApi, | |||
| '/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview', | |||
| '/datasets/notion-indexing-estimate') | |||
| api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets/<uuid:dataset_id>/notion/sync') | |||
| api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync') | |||
| api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates/<uuid:binding_id>/<string:action>") | |||
| api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") | |||
| api.add_resource( | |||
| DataSourceNotionApi, | |||
| "/notion/workspaces/<uuid:workspace_id>/pages/<uuid:page_id>/<string:page_type>/preview", | |||
| "/datasets/notion-indexing-estimate", | |||
| ) | |||
| api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets/<uuid:dataset_id>/notion/sync") | |||
| api.add_resource( | |||
| DataSourceNotionDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/notion/sync" | |||
| ) | |||
| @@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| raise ValueError('Name must be between 1 to 40 characters.') | |||
| raise ValueError("Name must be between 1 to 40 characters.") | |||
| return name | |||
| def _validate_description_length(description): | |||
| if len(description) > 400: | |||
| raise ValueError('Description cannot exceed 400 characters.') | |||
| raise ValueError("Description cannot exceed 400 characters.") | |||
| return description | |||
| class DatasetListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| ids = request.args.getlist('ids') | |||
| provider = request.args.get('provider', default="vendor") | |||
| search = request.args.get('keyword', default=None, type=str) | |||
| tag_ids = request.args.getlist('tag_ids') | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| ids = request.args.getlist("ids") | |||
| provider = request.args.get("provider", default="vendor") | |||
| search = request.args.get("keyword", default=None, type=str) | |||
| tag_ids = request.args.getlist("tag_ids") | |||
| if ids: | |||
| datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) | |||
| else: | |||
| datasets, total = DatasetService.get_datasets(page, limit, provider, | |||
| current_user.current_tenant_id, current_user, search, tag_ids) | |||
| datasets, total = DatasetService.get_datasets( | |||
| page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids | |||
| ) | |||
| # check embedding setting | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) | |||
| embedding_models = configurations.get_models( | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| only_active=True | |||
| ) | |||
| embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) | |||
| model_names = [] | |||
| for embedding_model in embedding_models: | |||
| @@ -77,28 +72,22 @@ class DatasetListApi(Resource): | |||
| data = marshal(datasets, dataset_detail_fields) | |||
| for item in data: | |||
| if item['indexing_technique'] == 'high_quality': | |||
| if item["indexing_technique"] == "high_quality": | |||
| item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | |||
| if item_model in model_names: | |||
| item['embedding_available'] = True | |||
| item["embedding_available"] = True | |||
| else: | |||
| item['embedding_available'] = False | |||
| item["embedding_available"] = False | |||
| else: | |||
| item['embedding_available'] = True | |||
| item["embedding_available"] = True | |||
| if item.get('permission') == 'partial_members': | |||
| part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id']) | |||
| item.update({'partial_member_list': part_users_list}) | |||
| if item.get("permission") == "partial_members": | |||
| part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) | |||
| item.update({"partial_member_list": part_users_list}) | |||
| else: | |||
| item.update({'partial_member_list': []}) | |||
| item.update({"partial_member_list": []}) | |||
| response = { | |||
| 'data': data, | |||
| 'has_more': len(datasets) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| } | |||
| response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} | |||
| return response, 200 | |||
| @setup_required | |||
| @@ -106,13 +95,21 @@ class DatasetListApi(Resource): | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', nullable=False, required=True, | |||
| help='type is required. Name must be between 1 to 40 characters.', | |||
| type=_validate_name) | |||
| parser.add_argument('indexing_technique', type=str, location='json', | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help='Invalid indexing technique.') | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| args = parser.parse_args() | |||
| # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator | |||
| @@ -122,9 +119,9 @@ class DatasetListApi(Resource): | |||
| try: | |||
| dataset = DatasetService.create_empty_dataset( | |||
| tenant_id=current_user.current_tenant_id, | |||
| name=args['name'], | |||
| indexing_technique=args['indexing_technique'], | |||
| account=current_user | |||
| name=args["name"], | |||
| indexing_technique=args["indexing_technique"], | |||
| account=current_user, | |||
| ) | |||
| except services.errors.dataset.DatasetNameDuplicateError: | |||
| raise DatasetNameDuplicateError() | |||
| @@ -142,42 +139,36 @@ class DatasetApi(Resource): | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| try: | |||
| DatasetService.check_dataset_permission( | |||
| dataset, current_user) | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| data = marshal(dataset, dataset_detail_fields) | |||
| if data.get('permission') == 'partial_members': | |||
| if data.get("permission") == "partial_members": | |||
| part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) | |||
| data.update({'partial_member_list': part_users_list}) | |||
| data.update({"partial_member_list": part_users_list}) | |||
| # check embedding setting | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) | |||
| embedding_models = configurations.get_models( | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| only_active=True | |||
| ) | |||
| embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) | |||
| model_names = [] | |||
| for embedding_model in embedding_models: | |||
| model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") | |||
| if data['indexing_technique'] == 'high_quality': | |||
| if data["indexing_technique"] == "high_quality": | |||
| item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" | |||
| if item_model in model_names: | |||
| data['embedding_available'] = True | |||
| data["embedding_available"] = True | |||
| else: | |||
| data['embedding_available'] = False | |||
| data["embedding_available"] = False | |||
| else: | |||
| data['embedding_available'] = True | |||
| data["embedding_available"] = True | |||
| if data.get('permission') == 'partial_members': | |||
| if data.get("permission") == "partial_members": | |||
| part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) | |||
| data.update({'partial_member_list': part_users_list}) | |||
| data.update({"partial_member_list": part_users_list}) | |||
| return data, 200 | |||
| @@ -191,42 +182,49 @@ class DatasetApi(Resource): | |||
| raise NotFound("Dataset not found.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', nullable=False, | |||
| help='type is required. Name must be between 1 to 40 characters.', | |||
| type=_validate_name) | |||
| parser.add_argument('description', | |||
| location='json', store_missing=False, | |||
| type=_validate_description_length) | |||
| parser.add_argument('indexing_technique', type=str, location='json', | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help='Invalid indexing technique.') | |||
| parser.add_argument('permission', type=str, location='json', choices=( | |||
| DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.' | |||
| ) | |||
| parser.add_argument('embedding_model', type=str, | |||
| location='json', help='Invalid embedding model.') | |||
| parser.add_argument('embedding_model_provider', type=str, | |||
| location='json', help='Invalid embedding model provider.') | |||
| parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') | |||
| parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.') | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) | |||
| parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| parser.add_argument( | |||
| "permission", | |||
| type=str, | |||
| location="json", | |||
| choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | |||
| help="Invalid permission.", | |||
| ) | |||
| parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") | |||
| parser.add_argument( | |||
| "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") | |||
| parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") | |||
| args = parser.parse_args() | |||
| data = request.get_json() | |||
| # check embedding model setting | |||
| if data.get('indexing_technique') == 'high_quality': | |||
| DatasetService.check_embedding_model_setting(dataset.tenant_id, | |||
| data.get('embedding_model_provider'), | |||
| data.get('embedding_model') | |||
| ) | |||
| if data.get("indexing_technique") == "high_quality": | |||
| DatasetService.check_embedding_model_setting( | |||
| dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") | |||
| ) | |||
| # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator | |||
| DatasetPermissionService.check_permission( | |||
| current_user, dataset, data.get('permission'), data.get('partial_member_list') | |||
| current_user, dataset, data.get("permission"), data.get("partial_member_list") | |||
| ) | |||
| dataset = DatasetService.update_dataset( | |||
| dataset_id_str, args, current_user) | |||
| dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -234,16 +232,19 @@ class DatasetApi(Resource): | |||
| result_data = marshal(dataset, dataset_detail_fields) | |||
| tenant_id = current_user.current_tenant_id | |||
| if data.get('partial_member_list') and data.get('permission') == 'partial_members': | |||
| if data.get("partial_member_list") and data.get("permission") == "partial_members": | |||
| DatasetPermissionService.update_partial_member_list( | |||
| tenant_id, dataset_id_str, data.get('partial_member_list') | |||
| tenant_id, dataset_id_str, data.get("partial_member_list") | |||
| ) | |||
| # clear partial member list when permission is only_me or all_team_members | |||
| elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM: | |||
| elif ( | |||
| data.get("permission") == DatasetPermissionEnum.ONLY_ME | |||
| or data.get("permission") == DatasetPermissionEnum.ALL_TEAM | |||
| ): | |||
| DatasetPermissionService.clear_partial_member_list(dataset_id_str) | |||
| partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) | |||
| result_data.update({'partial_member_list': partial_member_list}) | |||
| result_data.update({"partial_member_list": partial_member_list}) | |||
| return result_data, 200 | |||
| @@ -260,12 +261,13 @@ class DatasetApi(Resource): | |||
| try: | |||
| if DatasetService.delete_dataset(dataset_id_str, current_user): | |||
| DatasetPermissionService.clear_partial_member_list(dataset_id_str) | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| else: | |||
| raise NotFound("Dataset not found.") | |||
| except services.errors.dataset.DatasetInUseError: | |||
| raise DatasetInUseError() | |||
| class DatasetUseCheckApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -274,10 +276,10 @@ class DatasetUseCheckApi(Resource): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) | |||
| return {'is_using': dataset_is_using}, 200 | |||
| return {"is_using": dataset_is_using}, 200 | |||
| class DatasetQueryApi(Resource): | |||
| class DatasetQueryApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -292,51 +294,53 @@ class DatasetQueryApi(Resource): | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| dataset_queries, total = DatasetService.get_dataset_queries( | |||
| dataset_id=dataset.id, | |||
| page=page, | |||
| per_page=limit | |||
| ) | |||
| dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) | |||
| response = { | |||
| 'data': marshal(dataset_queries, dataset_query_detail_fields), | |||
| 'has_more': len(dataset_queries) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| "data": marshal(dataset_queries, dataset_query_detail_fields), | |||
| "has_more": len(dataset_queries) == limit, | |||
| "limit": limit, | |||
| "total": total, | |||
| "page": page, | |||
| } | |||
| return response, 200 | |||
| class DatasetIndexingEstimateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument('indexing_technique', type=str, required=True, | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") | |||
| parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") | |||
| parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| required=True, | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| location="json", | |||
| ) | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| args = parser.parse_args() | |||
| # validate args | |||
| DocumentService.estimate_args_validate(args) | |||
| extract_settings = [] | |||
| if args['info_list']['data_source_type'] == 'upload_file': | |||
| file_ids = args['info_list']['file_info_list']['file_ids'] | |||
| file_details = db.session.query(UploadFile).filter( | |||
| UploadFile.tenant_id == current_user.current_tenant_id, | |||
| UploadFile.id.in_(file_ids) | |||
| ).all() | |||
| if args["info_list"]["data_source_type"] == "upload_file": | |||
| file_ids = args["info_list"]["file_info_list"]["file_ids"] | |||
| file_details = ( | |||
| db.session.query(UploadFile) | |||
| .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) | |||
| .all() | |||
| ) | |||
| if file_details is None: | |||
| raise NotFound("File not found.") | |||
| @@ -344,55 +348,58 @@ class DatasetIndexingEstimateApi(Resource): | |||
| if file_details: | |||
| for file_detail in file_details: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="upload_file", | |||
| upload_file=file_detail, | |||
| document_model=args['doc_form'] | |||
| datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif args['info_list']['data_source_type'] == 'notion_import': | |||
| notion_info_list = args['info_list']['notion_info_list'] | |||
| elif args["info_list"]["data_source_type"] == "notion_import": | |||
| notion_info_list = args["info_list"]["notion_info_list"] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| for page in notion_info['pages']: | |||
| workspace_id = notion_info["workspace_id"] | |||
| for page in notion_info["pages"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="notion_import", | |||
| notion_info={ | |||
| "notion_workspace_id": workspace_id, | |||
| "notion_obj_id": page['page_id'], | |||
| "notion_page_type": page['type'], | |||
| "tenant_id": current_user.current_tenant_id | |||
| "notion_obj_id": page["page_id"], | |||
| "notion_page_type": page["type"], | |||
| "tenant_id": current_user.current_tenant_id, | |||
| }, | |||
| document_model=args['doc_form'] | |||
| document_model=args["doc_form"], | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif args['info_list']['data_source_type'] == 'website_crawl': | |||
| website_info_list = args['info_list']['website_info_list'] | |||
| for url in website_info_list['urls']: | |||
| elif args["info_list"]["data_source_type"] == "website_crawl": | |||
| website_info_list = args["info_list"]["website_info_list"] | |||
| for url in website_info_list["urls"]: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| website_info={ | |||
| "provider": website_info_list['provider'], | |||
| "job_id": website_info_list['job_id'], | |||
| "provider": website_info_list["provider"], | |||
| "job_id": website_info_list["job_id"], | |||
| "url": url, | |||
| "tenant_id": current_user.current_tenant_id, | |||
| "mode": 'crawl', | |||
| "only_main_content": website_info_list['only_main_content'] | |||
| "mode": "crawl", | |||
| "only_main_content": website_info_list["only_main_content"], | |||
| }, | |||
| document_model=args['doc_form'] | |||
| document_model=args["doc_form"], | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| else: | |||
| raise ValueError('Data source type not support') | |||
| raise ValueError("Data source type not support") | |||
| indexing_runner = IndexingRunner() | |||
| try: | |||
| response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, | |||
| args['process_rule'], args['doc_form'], | |||
| args['doc_language'], args['dataset_id'], | |||
| args['indexing_technique']) | |||
| response = indexing_runner.indexing_estimate( | |||
| current_user.current_tenant_id, | |||
| extract_settings, | |||
| args["process_rule"], | |||
| args["doc_form"], | |||
| args["doc_language"], | |||
| args["dataset_id"], | |||
| args["indexing_technique"], | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except Exception as e: | |||
| @@ -402,7 +409,6 @@ class DatasetIndexingEstimateApi(Resource): | |||
| class DatasetRelatedAppListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -426,52 +432,52 @@ class DatasetRelatedAppListApi(Resource): | |||
| if app_model: | |||
| related_apps.append(app_model) | |||
| return { | |||
| 'data': related_apps, | |||
| 'total': len(related_apps) | |||
| }, 200 | |||
| return {"data": related_apps, "total": len(related_apps)}, 200 | |||
| class DatasetIndexingStatusApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| documents = db.session.query(Document).filter( | |||
| Document.dataset_id == dataset_id, | |||
| Document.tenant_id == current_user.current_tenant_id | |||
| ).all() | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) | |||
| .all() | |||
| ) | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != 're_segment').count() | |||
| total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != 're_segment').count() | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| documents_status.append(marshal(document, document_status_fields)) | |||
| data = { | |||
| 'data': documents_status | |||
| } | |||
| data = {"data": documents_status} | |||
| return data | |||
| class DatasetApiKeyApi(Resource): | |||
| max_keys = 10 | |||
| token_prefix = 'dataset-' | |||
| resource_type = 'dataset' | |||
| token_prefix = "dataset-" | |||
| resource_type = "dataset" | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(api_key_list) | |||
| def get(self): | |||
| keys = db.session.query(ApiToken). \ | |||
| filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ | |||
| all() | |||
| keys = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .all() | |||
| ) | |||
| return {"items": keys} | |||
| @setup_required | |||
| @@ -483,15 +489,17 @@ class DatasetApiKeyApi(Resource): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| current_key_count = db.session.query(ApiToken). \ | |||
| filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ | |||
| count() | |||
| current_key_count = ( | |||
| db.session.query(ApiToken) | |||
| .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) | |||
| .count() | |||
| ) | |||
| if current_key_count >= self.max_keys: | |||
| flask_restful.abort( | |||
| 400, | |||
| message=f"Cannot create more than {self.max_keys} API keys for this resource type.", | |||
| code='max_keys_exceeded' | |||
| code="max_keys_exceeded", | |||
| ) | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| @@ -505,7 +513,7 @@ class DatasetApiKeyApi(Resource): | |||
| class DatasetApiDeleteApi(Resource): | |||
| resource_type = 'dataset' | |||
| resource_type = "dataset" | |||
| @setup_required | |||
| @login_required | |||
| @@ -517,18 +525,23 @@ class DatasetApiDeleteApi(Resource): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| key = db.session.query(ApiToken). \ | |||
| filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, | |||
| ApiToken.id == api_key_id). \ | |||
| first() | |||
| key = ( | |||
| db.session.query(ApiToken) | |||
| .filter( | |||
| ApiToken.tenant_id == current_user.current_tenant_id, | |||
| ApiToken.type == self.resource_type, | |||
| ApiToken.id == api_key_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if key is None: | |||
| flask_restful.abort(404, message='API key not found') | |||
| flask_restful.abort(404, message="API key not found") | |||
| db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() | |||
| db.session.commit() | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class DatasetApiBaseUrlApi(Resource): | |||
| @@ -537,8 +550,10 @@ class DatasetApiBaseUrlApi(Resource): | |||
| @account_initialization_required | |||
| def get(self): | |||
| return { | |||
| 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL | |||
| else request.host_url.rstrip('/')) + '/v1' | |||
| "api_base_url": ( | |||
| dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/") | |||
| ) | |||
| + "/v1" | |||
| } | |||
| @@ -549,15 +564,26 @@ class DatasetRetrievalSettingApi(Resource): | |||
| def get(self): | |||
| vector_type = dify_config.VECTOR_STORE | |||
| match vector_type: | |||
| case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: | |||
| return { | |||
| 'retrieval_method': [ | |||
| RetrievalMethod.SEMANTIC_SEARCH.value | |||
| ] | |||
| } | |||
| case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: | |||
| case ( | |||
| VectorType.MILVUS | |||
| | VectorType.RELYT | |||
| | VectorType.PGVECTOR | |||
| | VectorType.TIDB_VECTOR | |||
| | VectorType.CHROMA | |||
| | VectorType.TENCENT | |||
| ): | |||
| return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} | |||
| case ( | |||
| VectorType.QDRANT | |||
| | VectorType.WEAVIATE | |||
| | VectorType.OPENSEARCH | |||
| | VectorType.ANALYTICDB | |||
| | VectorType.MYSCALE | |||
| | VectorType.ORACLE | |||
| | VectorType.ELASTICSEARCH | |||
| ): | |||
| return { | |||
| 'retrieval_method': [ | |||
| "retrieval_method": [ | |||
| RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| RetrievalMethod.FULL_TEXT_SEARCH.value, | |||
| RetrievalMethod.HYBRID_SEARCH.value, | |||
| @@ -573,15 +599,27 @@ class DatasetRetrievalSettingMockApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, vector_type): | |||
| match vector_type: | |||
| case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS: | |||
| case ( | |||
| VectorType.MILVUS | |||
| | VectorType.RELYT | |||
| | VectorType.TIDB_VECTOR | |||
| | VectorType.CHROMA | |||
| | VectorType.TENCENT | |||
| | VectorType.PGVECTO_RS | |||
| ): | |||
| return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} | |||
| case ( | |||
| VectorType.QDRANT | |||
| | VectorType.WEAVIATE | |||
| | VectorType.OPENSEARCH | |||
| | VectorType.ANALYTICDB | |||
| | VectorType.MYSCALE | |||
| | VectorType.ORACLE | |||
| | VectorType.ELASTICSEARCH | |||
| | VectorType.PGVECTOR | |||
| ): | |||
| return { | |||
| 'retrieval_method': [ | |||
| RetrievalMethod.SEMANTIC_SEARCH.value | |||
| ] | |||
| } | |||
| case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR: | |||
| return { | |||
| 'retrieval_method': [ | |||
| "retrieval_method": [ | |||
| RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| RetrievalMethod.FULL_TEXT_SEARCH.value, | |||
| RetrievalMethod.HYBRID_SEARCH.value, | |||
| @@ -591,7 +629,6 @@ class DatasetRetrievalSettingMockApi(Resource): | |||
| raise ValueError(f"Unsupported vector db type {vector_type}.") | |||
| class DatasetErrorDocs(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -603,10 +640,7 @@ class DatasetErrorDocs(Resource): | |||
| raise NotFound("Dataset not found.") | |||
| results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) | |||
| return { | |||
| 'data': [marshal(item, document_status_fields) for item in results], | |||
| 'total': len(results) | |||
| }, 200 | |||
| return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 | |||
| class DatasetPermissionUserListApi(Resource): | |||
| @@ -626,21 +660,21 @@ class DatasetPermissionUserListApi(Resource): | |||
| partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) | |||
| return { | |||
| 'data': partial_members_list, | |||
| "data": partial_members_list, | |||
| }, 200 | |||
| api.add_resource(DatasetListApi, '/datasets') | |||
| api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') | |||
| api.add_resource(DatasetUseCheckApi, '/datasets/<uuid:dataset_id>/use-check') | |||
| api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries') | |||
| api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs') | |||
| api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') | |||
| api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps') | |||
| api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status') | |||
| api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') | |||
| api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>') | |||
| api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | |||
| api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') | |||
| api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>') | |||
| api.add_resource(DatasetPermissionUserListApi, '/datasets/<uuid:dataset_id>/permission-part-users') | |||
| api.add_resource(DatasetListApi, "/datasets") | |||
| api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | |||
| api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") | |||
| api.add_resource(DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries") | |||
| api.add_resource(DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs") | |||
| api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") | |||
| api.add_resource(DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps") | |||
| api.add_resource(DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status") | |||
| api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") | |||
| api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>") | |||
| api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") | |||
| api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") | |||
| api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") | |||
| api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") | |||
| @@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| document_id = str(document_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| try: | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| @@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=str, default=None, location='args') | |||
| parser.add_argument('limit', type=int, default=20, location='args') | |||
| parser.add_argument('status', type=str, | |||
| action='append', default=[], location='args') | |||
| parser.add_argument('hit_count_gte', type=int, | |||
| default=None, location='args') | |||
| parser.add_argument('enabled', type=str, default='all', location='args') | |||
| parser.add_argument('keyword', type=str, default=None, location='args') | |||
| parser.add_argument("last_id", type=str, default=None, location="args") | |||
| parser.add_argument("limit", type=int, default=20, location="args") | |||
| parser.add_argument("status", type=str, action="append", default=[], location="args") | |||
| parser.add_argument("hit_count_gte", type=int, default=None, location="args") | |||
| parser.add_argument("enabled", type=str, default="all", location="args") | |||
| parser.add_argument("keyword", type=str, default=None, location="args") | |||
| args = parser.parse_args() | |||
| last_id = args['last_id'] | |||
| limit = min(args['limit'], 100) | |||
| status_list = args['status'] | |||
| hit_count_gte = args['hit_count_gte'] | |||
| keyword = args['keyword'] | |||
| last_id = args["last_id"] | |||
| limit = min(args["limit"], 100) | |||
| status_list = args["status"] | |||
| hit_count_gte = args["hit_count_gte"] | |||
| keyword = args["keyword"] | |||
| query = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| if last_id is not None: | |||
| last_segment = db.session.get(DocumentSegment, str(last_id)) | |||
| if last_segment: | |||
| query = query.filter( | |||
| DocumentSegment.position > last_segment.position) | |||
| query = query.filter(DocumentSegment.position > last_segment.position) | |||
| else: | |||
| return {'data': [], 'has_more': False, 'limit': limit}, 200 | |||
| return {"data": [], "has_more": False, "limit": limit}, 200 | |||
| if status_list: | |||
| query = query.filter(DocumentSegment.status.in_(status_list)) | |||
| @@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| query = query.filter(DocumentSegment.hit_count >= hit_count_gte) | |||
| if keyword: | |||
| query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) | |||
| query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) | |||
| if args['enabled'].lower() != 'all': | |||
| if args['enabled'].lower() == 'true': | |||
| if args["enabled"].lower() != "all": | |||
| if args["enabled"].lower() == "true": | |||
| query = query.filter(DocumentSegment.enabled == True) | |||
| elif args['enabled'].lower() == 'false': | |||
| elif args["enabled"].lower() == "false": | |||
| query = query.filter(DocumentSegment.enabled == False) | |||
| total = query.count() | |||
| @@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| segments = segments[:-1] | |||
| return { | |||
| 'data': marshal(segments, segment_fields), | |||
| 'doc_form': document.doc_form, | |||
| 'has_more': has_more, | |||
| 'limit': limit, | |||
| 'total': total | |||
| "data": marshal(segments, segment_fields), | |||
| "doc_form": document.doc_form, | |||
| "has_more": has_more, | |||
| "limit": limit, | |||
| "total": total, | |||
| }, 200 | |||
| @@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('vector_space') | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| def patch(self, dataset_id, segment_id, action): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| @@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource): | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.indexing_technique == "high_quality": | |||
| # check embedding model setting | |||
| try: | |||
| model_manager = ModelManager() | |||
| @@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource): | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound('Segment not found.') | |||
| raise NotFound("Segment not found.") | |||
| if segment.status != 'completed': | |||
| raise NotFound('Segment is not completed, enable or disable function is not allowed') | |||
| if segment.status != "completed": | |||
| raise NotFound("Segment is not completed, enable or disable function is not allowed") | |||
| document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) | |||
| document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) | |||
| cache_result = redis_client.get(document_indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise InvalidActionError("Document is being indexed, please try again later") | |||
| indexing_cache_key = 'segment_{}_indexing'.format(segment.id) | |||
| indexing_cache_key = "segment_{}_indexing".format(segment.id) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise InvalidActionError("Segment is being indexed, please try again later") | |||
| @@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): | |||
| enable_segment_to_index_task.delay(segment.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| elif action == "disable": | |||
| if not segment.enabled: | |||
| raise InvalidActionError("Segment is already disabled.") | |||
| @@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource): | |||
| disable_segment_from_index_task.delay(segment.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| else: | |||
| raise InvalidActionError() | |||
| @@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('vector_space') | |||
| @cloud_edition_billing_knowledge_limit_check('add_segment') | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment") | |||
| def post(self, dataset_id, document_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.indexing_technique == "high_quality": | |||
| try: | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| try: | |||
| @@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource): | |||
| raise Forbidden(str(e)) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('content', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('answer', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("answer", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| SegmentService.segment_create_args_validate(args, document) | |||
| segment = SegmentService.create_segment(args, document, dataset) | |||
| return { | |||
| 'data': marshal(segment, segment_fields), | |||
| 'doc_form': document.doc_form | |||
| }, 200 | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| class DatasetDocumentSegmentUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('vector_space') | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| def patch(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| if dataset.indexing_technique == 'high_quality': | |||
| raise NotFound("Document not found.") | |||
| if dataset.indexing_technique == "high_quality": | |||
| # check embedding model setting | |||
| try: | |||
| model_manager = ModelManager() | |||
| @@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound('Segment not found.') | |||
| raise NotFound("Segment not found.") | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| raise Forbidden(str(e)) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('content', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('answer', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("answer", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| SegmentService.segment_create_args_validate(args, document) | |||
| segment = SegmentService.update_segment(args, segment, document, dataset) | |||
| return { | |||
| 'data': marshal(segment, segment_fields), | |||
| 'doc_form': document.doc_form | |||
| }, 200 | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound('Segment not found.') | |||
| raise NotFound("Segment not found.") | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| SegmentService.delete_segment(segment, document, dataset) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class DatasetDocumentSegmentBatchImportApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('vector_space') | |||
| @cloud_edition_billing_knowledge_limit_check('add_segment') | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment") | |||
| def post(self, dataset_id, document_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| # get file from request | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| # check file type | |||
| if not file.filename.endswith('.csv'): | |||
| if not file.filename.endswith(".csv"): | |||
| raise ValueError("Invalid file type. Only CSV files are allowed") | |||
| try: | |||
| @@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource): | |||
| df = pd.read_csv(file) | |||
| result = [] | |||
| for index, row in df.iterrows(): | |||
| if document.doc_form == 'qa_model': | |||
| data = {'content': row[0], 'answer': row[1]} | |||
| if document.doc_form == "qa_model": | |||
| data = {"content": row[0], "answer": row[1]} | |||
| else: | |||
| data = {'content': row[0]} | |||
| data = {"content": row[0]} | |||
| result.append(data) | |||
| if len(result) == 0: | |||
| raise ValueError("The CSV file is empty.") | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) | |||
| indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(indexing_cache_key, 'waiting') | |||
| batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, | |||
| current_user.current_tenant_id, current_user.id) | |||
| redis_client.setnx(indexing_cache_key, "waiting") | |||
| batch_create_segment_to_index_task.delay( | |||
| str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id | |||
| ) | |||
| except Exception as e: | |||
| return {'error': str(e)}, 500 | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| }, 200 | |||
| return {"error": str(e)}, 500 | |||
| return {"job_id": job_id, "job_status": "waiting"}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, job_id): | |||
| job_id = str(job_id) | |||
| indexing_cache_key = 'segment_batch_import_{}'.format(job_id) | |||
| indexing_cache_key = "segment_batch_import_{}".format(job_id) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is None: | |||
| raise ValueError("The job is not exist.") | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': cache_result.decode() | |||
| }, 200 | |||
| return {"job_id": job_id, "job_status": cache_result.decode()}, 200 | |||
| api.add_resource(DatasetDocumentSegmentListApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments') | |||
| api.add_resource(DatasetDocumentSegmentApi, | |||
| '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>') | |||
| api.add_resource(DatasetDocumentSegmentAddApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment') | |||
| api.add_resource(DatasetDocumentSegmentUpdateApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>') | |||
| api.add_resource(DatasetDocumentSegmentBatchImportApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import', | |||
| '/datasets/batch_import_status/<uuid:job_id>') | |||
| api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") | |||
| api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") | |||
| api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") | |||
| api.add_resource( | |||
| DatasetDocumentSegmentUpdateApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>", | |||
| ) | |||
| api.add_resource( | |||
| DatasetDocumentSegmentBatchImportApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", | |||
| "/datasets/batch_import_status/<uuid:job_id>", | |||
| ) | |||
| @@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| error_code = "no_file_uploaded" | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| error_code = "too_many_files" | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| class FileTooLargeError(BaseHTTPException): | |||
| error_code = 'file_too_large' | |||
| error_code = "file_too_large" | |||
| description = "File size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| error_code = "unsupported_file_type" | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| class HighQualityDatasetOnlyError(BaseHTTPException): | |||
| error_code = 'high_quality_dataset_only' | |||
| error_code = "high_quality_dataset_only" | |||
| description = "Current operation only supports 'high-quality' datasets." | |||
| code = 400 | |||
| class DatasetNotInitializedError(BaseHTTPException): | |||
| error_code = 'dataset_not_initialized' | |||
| error_code = "dataset_not_initialized" | |||
| description = "The dataset is still being initialized or indexing. Please wait a moment." | |||
| code = 400 | |||
| class ArchivedDocumentImmutableError(BaseHTTPException): | |||
| error_code = 'archived_document_immutable' | |||
| error_code = "archived_document_immutable" | |||
| description = "The archived document is not editable." | |||
| code = 403 | |||
| class DatasetNameDuplicateError(BaseHTTPException): | |||
| error_code = 'dataset_name_duplicate' | |||
| error_code = "dataset_name_duplicate" | |||
| description = "The dataset name already exists. Please modify your dataset name." | |||
| code = 409 | |||
| class InvalidActionError(BaseHTTPException): | |||
| error_code = 'invalid_action' | |||
| error_code = "invalid_action" | |||
| description = "Invalid action." | |||
| code = 400 | |||
| class DocumentAlreadyFinishedError(BaseHTTPException): | |||
| error_code = 'document_already_finished' | |||
| error_code = "document_already_finished" | |||
| description = "The document has been processed. Please refresh the page or go to the document details." | |||
| code = 400 | |||
| class DocumentIndexingError(BaseHTTPException): | |||
| error_code = 'document_indexing' | |||
| error_code = "document_indexing" | |||
| description = "The document is being processed and cannot be edited." | |||
| code = 400 | |||
| class InvalidMetadataError(BaseHTTPException): | |||
| error_code = 'invalid_metadata' | |||
| error_code = "invalid_metadata" | |||
| description = "The metadata content is incorrect. Please check and verify." | |||
| code = 400 | |||
| class WebsiteCrawlError(BaseHTTPException): | |||
| error_code = 'crawl_failed' | |||
| error_code = "crawl_failed" | |||
| description = "{message}" | |||
| code = 500 | |||
| class DatasetInUseError(BaseHTTPException): | |||
| error_code = 'dataset_in_use' | |||
| error_code = "dataset_in_use" | |||
| description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." | |||
| code = 409 | |||
| class IndexingEstimateError(BaseHTTPException): | |||
| error_code = 'indexing_estimate_error' | |||
| error_code = "indexing_estimate_error" | |||
| description = "Knowledge indexing estimate failed: {message}" | |||
| code = 500 | |||
| @@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000 | |||
| class FileApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -31,23 +30,22 @@ class FileApi(Resource): | |||
| batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT | |||
| image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT | |||
| return { | |||
| 'file_size_limit': file_size_limit, | |||
| 'batch_count_limit': batch_count_limit, | |||
| 'image_file_size_limit': image_file_size_limit | |||
| "file_size_limit": file_size_limit, | |||
| "batch_count_limit": batch_count_limit, | |||
| "image_file_size_limit": image_file_size_limit, | |||
| }, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(file_fields) | |||
| @cloud_edition_billing_resource_check(resource='documents') | |||
| @cloud_edition_billing_resource_check(resource="documents") | |||
| def post(self): | |||
| # get file from request | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| @@ -69,7 +67,7 @@ class FilePreviewApi(Resource): | |||
| def get(self, file_id): | |||
| file_id = str(file_id) | |||
| text = FileService.get_file_preview(file_id) | |||
| return {'content': text} | |||
| return {"content": text} | |||
| class FileSupportTypeApi(Resource): | |||
| @@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource): | |||
| @account_initialization_required | |||
| def get(self): | |||
| etl_type = dify_config.ETL_TYPE | |||
| allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS | |||
| return {'allowed_extensions': allowed_extensions} | |||
| allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS | |||
| return {"allowed_extensions": allowed_extensions} | |||
| api.add_resource(FileApi, '/files/upload') | |||
| api.add_resource(FilePreviewApi, '/files/<uuid:file_id>/preview') | |||
| api.add_resource(FileSupportTypeApi, '/files/support-type') | |||
| api.add_resource(FileApi, "/files/upload") | |||
| api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") | |||
| api.add_resource(FileSupportTypeApi, "/files/support-type") | |||
| @@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService | |||
| class HitTestingApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -46,8 +45,8 @@ class HitTestingApi(Resource): | |||
| raise Forbidden(str(e)) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('query', type=str, location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, location='json') | |||
| parser.add_argument("query", type=str, location="json") | |||
| parser.add_argument("retrieval_model", type=dict, required=False, location="json") | |||
| args = parser.parse_args() | |||
| HitTestingService.hit_testing_args_check(args) | |||
| @@ -55,13 +54,13 @@ class HitTestingApi(Resource): | |||
| try: | |||
| response = HitTestingService.retrieve( | |||
| dataset=dataset, | |||
| query=args['query'], | |||
| query=args["query"], | |||
| account=current_user, | |||
| retrieval_model=args['retrieval_model'], | |||
| limit=10 | |||
| retrieval_model=args["retrieval_model"], | |||
| limit=10, | |||
| ) | |||
| return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} | |||
| return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} | |||
| except services.errors.index.IndexNotInitializedError: | |||
| raise DatasetNotInitializedError() | |||
| except ProviderTokenNotInitError as ex: | |||
| @@ -73,7 +72,8 @@ class HitTestingApi(Resource): | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model or Reranking Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except ValueError as e: | |||
| @@ -83,4 +83,4 @@ class HitTestingApi(Resource): | |||
| raise InternalServerError(str(e)) | |||
| api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing') | |||
| api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") | |||
| @@ -9,16 +9,14 @@ from services.website_service import WebsiteService | |||
| class WebsiteCrawlApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, choices=['firecrawl'], | |||
| required=True, nullable=True, location='json') | |||
| parser.add_argument('url', type=str, required=True, nullable=True, location='json') | |||
| parser.add_argument('options', type=dict, required=True, nullable=True, location='json') | |||
| parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json") | |||
| parser.add_argument("url", type=str, required=True, nullable=True, location="json") | |||
| parser.add_argument("options", type=dict, required=True, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| WebsiteService.document_create_args_validate(args) | |||
| # crawl url | |||
| @@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, job_id: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') | |||
| parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args") | |||
| args = parser.parse_args() | |||
| # get crawl status | |||
| try: | |||
| result = WebsiteService.get_crawl_status(job_id, args['provider']) | |||
| result = WebsiteService.get_crawl_status(job_id, args["provider"]) | |||
| except Exception as e: | |||
| raise WebsiteCrawlError(str(e)) | |||
| return result, 200 | |||
| api.add_resource(WebsiteCrawlApi, '/website/crawl') | |||
| api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>') | |||
| api.add_resource(WebsiteCrawlApi, "/website/crawl") | |||
| api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/<string:job_id>") | |||
| @@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException | |||
| class AlreadySetupError(BaseHTTPException): | |||
| error_code = 'already_setup' | |||
| error_code = "already_setup" | |||
| description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." | |||
| code = 403 | |||
| class NotSetupError(BaseHTTPException): | |||
| error_code = 'not_setup' | |||
| description = "Dify has not been initialized and installed yet. " \ | |||
| "Please proceed with the initialization and installation process first." | |||
| error_code = "not_setup" | |||
| description = ( | |||
| "Dify has not been initialized and installed yet. " | |||
| "Please proceed with the initialization and installation process first." | |||
| ) | |||
| code = 401 | |||
| class NotInitValidateError(BaseHTTPException): | |||
| error_code = 'not_init_validated' | |||
| description = "Init validation has not been completed yet. " \ | |||
| "Please proceed with the init validation process first." | |||
| error_code = "not_init_validated" | |||
| description = ( | |||
| "Init validation has not been completed yet. " "Please proceed with the init validation process first." | |||
| ) | |||
| code = 401 | |||
| class InitValidateFailedError(BaseHTTPException): | |||
| error_code = 'init_validate_failed' | |||
| error_code = "init_validate_failed" | |||
| description = "Init validation failed. Please check the password and try again." | |||
| code = 401 | |||
| class AccountNotLinkTenantError(BaseHTTPException): | |||
| error_code = 'account_not_link_tenant' | |||
| error_code = "account_not_link_tenant" | |||
| description = "Account not link tenant." | |||
| code = 403 | |||
| class AlreadyActivateError(BaseHTTPException): | |||
| error_code = 'already_activate' | |||
| error_code = "already_activate" | |||
| description = "Auth Token is invalid or account already activated, please check again." | |||
| code = 403 | |||
| @@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource): | |||
| def post(self, installed_app): | |||
| app_model = installed_app.app | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| try: | |||
| response = AudioService.transcript_asr( | |||
| app_model=app_model, | |||
| file=file, | |||
| end_user=None | |||
| ) | |||
| response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| @@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource): | |||
| app_model = installed_app.app | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', type=str, required=False, location='json') | |||
| parser.add_argument('voice', type=str, location='json') | |||
| parser.add_argument('text', type=str, location='json') | |||
| parser.add_argument('streaming', type=bool, location='json') | |||
| parser.add_argument("message_id", type=str, required=False, location="json") | |||
| parser.add_argument("voice", type=str, location="json") | |||
| parser.add_argument("text", type=str, location="json") | |||
| parser.add_argument("streaming", type=bool, location="json") | |||
| args = parser.parse_args() | |||
| message_id = args.get('message_id', None) | |||
| text = args.get('text', None) | |||
| if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict): | |||
| text_to_speech = app_model.workflow.features_dict.get('text_to_speech') | |||
| voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| text_to_speech = app_model.workflow.features_dict.get("text_to_speech") | |||
| voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") | |||
| else: | |||
| try: | |||
| voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') | |||
| voice = ( | |||
| args.get("voice") | |||
| if args.get("voice") | |||
| else app_model.app_model_config.text_to_speech_dict.get("voice") | |||
| ) | |||
| except Exception: | |||
| voice = None | |||
| response = AudioService.transcript_tts( | |||
| app_model=app_model, | |||
| message_id=message_id, | |||
| voice=voice, | |||
| text=text | |||
| ) | |||
| response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| @@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource): | |||
| raise InternalServerError() | |||
| api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio') | |||
| api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text') | |||
| api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio") | |||
| api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text") | |||
| # api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id', | |||
| # endpoint='installed_app_text_with_message_id') | |||
| @@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService | |||
| # define completion api for user | |||
| class CompletionApi(InstalledAppResource): | |||
| def post(self, installed_app): | |||
| app_model = installed_app.app | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, location="json", default="") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| args['auto_generate_name'] = False | |||
| streaming = args["response_mode"] == "streaming" | |||
| args["auto_generate_name"] = False | |||
| installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=streaming | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource): | |||
| class CompletionStopApi(InstalledAppResource): | |||
| def post(self, installed_app, task_id): | |||
| app_model = installed_app.app | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class ChatApi(InstalledAppResource): | |||
| @@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("conversation_id", type=uuid_value, location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") | |||
| args = parser.parse_args() | |||
| args['auto_generate_name'] = False | |||
| args["auto_generate_name"] = False | |||
| installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=True | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(CompletionApi, '/installed-apps/<uuid:installed_app_id>/completion-messages', endpoint='installed_app_completion') | |||
| api.add_resource(CompletionStopApi, '/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop', endpoint='installed_app_stop_completion') | |||
| api.add_resource(ChatApi, '/installed-apps/<uuid:installed_app_id>/chat-messages', endpoint='installed_app_chat_completion') | |||
| api.add_resource(ChatStopApi, '/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop', endpoint='installed_app_stop_chat_completion') | |||
| api.add_resource( | |||
| CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion" | |||
| ) | |||
| api.add_resource( | |||
| CompletionStopApi, | |||
| "/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop", | |||
| endpoint="installed_app_stop_completion", | |||
| ) | |||
| api.add_resource( | |||
| ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion" | |||
| ) | |||
| api.add_resource( | |||
| ChatStopApi, | |||
| "/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop", | |||
| endpoint="installed_app_stop_chat_completion", | |||
| ) | |||
| @@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService | |||
| class ConversationListApi(InstalledAppResource): | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| def get(self, installed_app): | |||
| app_model = installed_app.app | |||
| @@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") | |||
| args = parser.parse_args() | |||
| pinned = None | |||
| if 'pinned' in args and args['pinned'] is not None: | |||
| pinned = True if args['pinned'] == 'true' else False | |||
| if "pinned" in args and args["pinned"] is not None: | |||
| pinned = True if args["pinned"] == "true" else False | |||
| try: | |||
| return WebConversationService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| last_id=args['last_id'], | |||
| limit=args['limit'], | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| pinned=pinned, | |||
| ) | |||
| @@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource): | |||
| class ConversationRenameApi(InstalledAppResource): | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| @@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource): | |||
| conversation_id = str(c_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=False, location='json') | |||
| parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') | |||
| parser.add_argument("name", type=str, required=False, location="json") | |||
| parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| return ConversationService.rename( | |||
| app_model, | |||
| conversation_id, | |||
| current_user, | |||
| args['name'], | |||
| args['auto_generate'] | |||
| app_model, conversation_id, current_user, args["name"], args["auto_generate"] | |||
| ) | |||
| except ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| class ConversationPinApi(InstalledAppResource): | |||
| def patch(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| @@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource): | |||
| return {"result": "success"} | |||
| api.add_resource(ConversationRenameApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name', endpoint='installed_app_conversation_rename') | |||
| api.add_resource(ConversationListApi, '/installed-apps/<uuid:installed_app_id>/conversations', endpoint='installed_app_conversations') | |||
| api.add_resource(ConversationApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>', endpoint='installed_app_conversation') | |||
| api.add_resource(ConversationPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin', endpoint='installed_app_conversation_pin') | |||
| api.add_resource(ConversationUnPinApi, '/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin', endpoint='installed_app_conversation_unpin') | |||
| api.add_resource( | |||
| ConversationRenameApi, | |||
| "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name", | |||
| endpoint="installed_app_conversation_rename", | |||
| ) | |||
| api.add_resource( | |||
| ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations" | |||
| ) | |||
| api.add_resource( | |||
| ConversationApi, | |||
| "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>", | |||
| endpoint="installed_app_conversation", | |||
| ) | |||
| api.add_resource( | |||
| ConversationPinApi, | |||
| "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin", | |||
| endpoint="installed_app_conversation_pin", | |||
| ) | |||
| api.add_resource( | |||
| ConversationUnPinApi, | |||
| "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin", | |||
| endpoint="installed_app_conversation_unpin", | |||
| ) | |||
| @@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException | |||
| class NotCompletionAppError(BaseHTTPException): | |||
| error_code = 'not_completion_app' | |||
| error_code = "not_completion_app" | |||
| description = "Not Completion App" | |||
| code = 400 | |||
| class NotChatAppError(BaseHTTPException): | |||
| error_code = 'not_chat_app' | |||
| error_code = "not_chat_app" | |||
| description = "App mode is invalid." | |||
| code = 400 | |||
| class NotWorkflowAppError(BaseHTTPException): | |||
| error_code = 'not_workflow_app' | |||
| error_code = "not_workflow_app" | |||
| description = "Only support workflow app." | |||
| code = 400 | |||
| class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): | |||
| error_code = 'app_suggested_questions_after_answer_disabled' | |||
| error_code = "app_suggested_questions_after_answer_disabled" | |||
| description = "Function Suggested questions after answer disabled." | |||
| code = 403 | |||
| @@ -21,72 +21,71 @@ class InstalledAppsListApi(Resource): | |||
| @marshal_with(installed_app_list_fields) | |||
| def get(self): | |||
| current_tenant_id = current_user.current_tenant_id | |||
| installed_apps = db.session.query(InstalledApp).filter( | |||
| InstalledApp.tenant_id == current_tenant_id | |||
| ).all() | |||
| installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() | |||
| current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) | |||
| installed_apps = [ | |||
| { | |||
| 'id': installed_app.id, | |||
| 'app': installed_app.app, | |||
| 'app_owner_tenant_id': installed_app.app_owner_tenant_id, | |||
| 'is_pinned': installed_app.is_pinned, | |||
| 'last_used_at': installed_app.last_used_at, | |||
| 'editable': current_user.role in ["owner", "admin"], | |||
| 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id | |||
| "id": installed_app.id, | |||
| "app": installed_app.app, | |||
| "app_owner_tenant_id": installed_app.app_owner_tenant_id, | |||
| "is_pinned": installed_app.is_pinned, | |||
| "last_used_at": installed_app.last_used_at, | |||
| "editable": current_user.role in ["owner", "admin"], | |||
| "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, | |||
| } | |||
| for installed_app in installed_apps | |||
| ] | |||
| installed_apps.sort(key=lambda app: (-app['is_pinned'], | |||
| app['last_used_at'] is None, | |||
| -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) | |||
| installed_apps.sort( | |||
| key=lambda app: ( | |||
| -app["is_pinned"], | |||
| app["last_used_at"] is None, | |||
| -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, | |||
| ) | |||
| ) | |||
| return {'installed_apps': installed_apps} | |||
| return {"installed_apps": installed_apps} | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('apps') | |||
| @cloud_edition_billing_resource_check("apps") | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') | |||
| parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") | |||
| args = parser.parse_args() | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() | |||
| recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() | |||
| if recommended_app is None: | |||
| raise NotFound('App not found') | |||
| raise NotFound("App not found") | |||
| current_tenant_id = current_user.current_tenant_id | |||
| app = db.session.query(App).filter( | |||
| App.id == args['app_id'] | |||
| ).first() | |||
| app = db.session.query(App).filter(App.id == args["app_id"]).first() | |||
| if app is None: | |||
| raise NotFound('App not found') | |||
| raise NotFound("App not found") | |||
| if not app.is_public: | |||
| raise Forbidden('You can\'t install a non-public app') | |||
| raise Forbidden("You can't install a non-public app") | |||
| installed_app = InstalledApp.query.filter(and_( | |||
| InstalledApp.app_id == args['app_id'], | |||
| InstalledApp.tenant_id == current_tenant_id | |||
| )).first() | |||
| installed_app = InstalledApp.query.filter( | |||
| and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) | |||
| ).first() | |||
| if installed_app is None: | |||
| # todo: position | |||
| recommended_app.install_count += 1 | |||
| new_installed_app = InstalledApp( | |||
| app_id=args['app_id'], | |||
| app_id=args["app_id"], | |||
| tenant_id=current_tenant_id, | |||
| app_owner_tenant_id=app.tenant_id, | |||
| is_pinned=False, | |||
| last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) | |||
| last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| db.session.add(new_installed_app) | |||
| db.session.commit() | |||
| return {'message': 'App installed successfully'} | |||
| return {"message": "App installed successfully"} | |||
| class InstalledAppApi(InstalledAppResource): | |||
| @@ -94,30 +93,31 @@ class InstalledAppApi(InstalledAppResource): | |||
| update and delete an installed app | |||
| use InstalledAppResource to apply default decorators and get installed_app | |||
| """ | |||
| def delete(self, installed_app): | |||
| if installed_app.app_owner_tenant_id == current_user.current_tenant_id: | |||
| raise BadRequest('You can\'t uninstall an app owned by the current tenant') | |||
| raise BadRequest("You can't uninstall an app owned by the current tenant") | |||
| db.session.delete(installed_app) | |||
| db.session.commit() | |||
| return {'result': 'success', 'message': 'App uninstalled successfully'} | |||
| return {"result": "success", "message": "App uninstalled successfully"} | |||
| def patch(self, installed_app): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('is_pinned', type=inputs.boolean) | |||
| parser.add_argument("is_pinned", type=inputs.boolean) | |||
| args = parser.parse_args() | |||
| commit_args = False | |||
| if 'is_pinned' in args: | |||
| installed_app.is_pinned = args['is_pinned'] | |||
| if "is_pinned" in args: | |||
| installed_app.is_pinned = args["is_pinned"] | |||
| commit_args = True | |||
| if commit_args: | |||
| db.session.commit() | |||
| return {'result': 'success', 'message': 'App info updated successfully'} | |||
| return {"result": "success", "message": "App info updated successfully"} | |||
| api.add_resource(InstalledAppsListApi, '/installed-apps') | |||
| api.add_resource(InstalledAppApi, '/installed-apps/<uuid:installed_app_id>') | |||
| api.add_resource(InstalledAppsListApi, "/installed-apps") | |||
| api.add_resource(InstalledAppApi, "/installed-apps/<uuid:installed_app_id>") | |||
| @@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') | |||
| parser.add_argument('first_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") | |||
| parser.add_argument("first_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| return MessageService.pagination_by_first_id(app_model, current_user, | |||
| args['conversation_id'], args['first_id'], args['limit']) | |||
| return MessageService.pagination_by_first_id( | |||
| app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] | |||
| ) | |||
| except services.errors.conversation.ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| except services.errors.message.FirstMessageNotExistsError: | |||
| raise NotFound("First Message Not Exists.") | |||
| class MessageFeedbackApi(InstalledAppResource): | |||
| def post(self, installed_app, message_id): | |||
| app_model = installed_app.app | |||
| @@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource): | |||
| message_id = str(message_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') | |||
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| MessageService.create_feedback(app_model, message_id, current_user, args['rating']) | |||
| MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) | |||
| except services.errors.message.MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class MessageMoreLikeThisApi(InstalledAppResource): | |||
| def get(self, installed_app, message_id): | |||
| app_model = installed_app.app | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| message_id = str(message_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') | |||
| parser.add_argument( | |||
| "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" | |||
| ) | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| streaming = args["response_mode"] == "streaming" | |||
| try: | |||
| response = AppGenerateService.generate_more_like_this( | |||
| @@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): | |||
| user=current_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=streaming | |||
| streaming=streaming, | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| except MessageNotExistsError: | |||
| @@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): | |||
| try: | |||
| questions = MessageService.get_suggested_questions_after_answer( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.EXPLORE | |||
| app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE | |||
| ) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message not found") | |||
| @@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {'data': questions} | |||
| return {"data": questions} | |||
| api.add_resource(MessageListApi, '/installed-apps/<uuid:installed_app_id>/messages', endpoint='installed_app_messages') | |||
| api.add_resource(MessageFeedbackApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks', endpoint='installed_app_message_feedback') | |||
| api.add_resource(MessageMoreLikeThisApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this', endpoint='installed_app_more_like_this') | |||
| api.add_resource(MessageSuggestedQuestionApi, '/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions', endpoint='installed_app_suggested_question') | |||
| api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages") | |||
| api.add_resource( | |||
| MessageFeedbackApi, | |||
| "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks", | |||
| endpoint="installed_app_message_feedback", | |||
| ) | |||
| api.add_resource( | |||
| MessageMoreLikeThisApi, | |||
| "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this", | |||
| endpoint="installed_app_more_like_this", | |||
| ) | |||
| api.add_resource( | |||
| MessageSuggestedQuestionApi, | |||
| "/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions", | |||
| endpoint="installed_app_suggested_question", | |||
| ) | |||
| @@ -1,4 +1,3 @@ | |||
| from flask_restful import fields, marshal_with | |||
| from configs import dify_config | |||
| @@ -11,33 +10,32 @@ from services.app_service import AppService | |||
| class AppParameterApi(InstalledAppResource): | |||
| """Resource for app variables.""" | |||
| variable_fields = { | |||
| 'key': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'type': fields.String, | |||
| 'default': fields.String, | |||
| 'max_length': fields.Integer, | |||
| 'options': fields.List(fields.String) | |||
| "key": fields.String, | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "type": fields.String, | |||
| "default": fields.String, | |||
| "max_length": fields.Integer, | |||
| "options": fields.List(fields.String), | |||
| } | |||
| system_parameters_fields = { | |||
| 'image_file_size_limit': fields.String | |||
| } | |||
| system_parameters_fields = {"image_file_size_limit": fields.String} | |||
| parameters_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'text_to_speech': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'sensitive_word_avoidance': fields.Raw, | |||
| 'file_upload': fields.Raw, | |||
| 'system_parameters': fields.Nested(system_parameters_fields) | |||
| "opening_statement": fields.String, | |||
| "suggested_questions": fields.Raw, | |||
| "suggested_questions_after_answer": fields.Raw, | |||
| "speech_to_text": fields.Raw, | |||
| "text_to_speech": fields.Raw, | |||
| "retriever_resource": fields.Raw, | |||
| "annotation_reply": fields.Raw, | |||
| "more_like_this": fields.Raw, | |||
| "user_input_form": fields.Raw, | |||
| "sensitive_word_avoidance": fields.Raw, | |||
| "file_upload": fields.Raw, | |||
| "system_parameters": fields.Nested(system_parameters_fields), | |||
| } | |||
| @marshal_with(parameters_fields) | |||
| @@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource): | |||
| app_model_config = app_model.app_model_config | |||
| features_dict = app_model_config.to_dict() | |||
| user_input_form = features_dict.get('user_input_form', []) | |||
| user_input_form = features_dict.get("user_input_form", []) | |||
| return { | |||
| 'opening_statement': features_dict.get('opening_statement'), | |||
| 'suggested_questions': features_dict.get('suggested_questions', []), | |||
| 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', | |||
| {"enabled": False}), | |||
| 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), | |||
| 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), | |||
| 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), | |||
| 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), | |||
| 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), | |||
| 'user_input_form': user_input_form, | |||
| 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', | |||
| {"enabled": False, "type": "", "configs": []}), | |||
| 'file_upload': features_dict.get('file_upload', {"image": { | |||
| "enabled": False, | |||
| "number_limits": 3, | |||
| "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"] | |||
| }}), | |||
| 'system_parameters': { | |||
| 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT | |||
| } | |||
| "opening_statement": features_dict.get("opening_statement"), | |||
| "suggested_questions": features_dict.get("suggested_questions", []), | |||
| "suggested_questions_after_answer": features_dict.get( | |||
| "suggested_questions_after_answer", {"enabled": False} | |||
| ), | |||
| "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), | |||
| "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), | |||
| "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), | |||
| "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), | |||
| "more_like_this": features_dict.get("more_like_this", {"enabled": False}), | |||
| "user_input_form": user_input_form, | |||
| "sensitive_word_avoidance": features_dict.get( | |||
| "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} | |||
| ), | |||
| "file_upload": features_dict.get( | |||
| "file_upload", | |||
| { | |||
| "image": { | |||
| "enabled": False, | |||
| "number_limits": 3, | |||
| "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"], | |||
| } | |||
| }, | |||
| ), | |||
| "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, | |||
| } | |||
| @@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource): | |||
| return AppService().get_app_meta(app_model) | |||
| api.add_resource(AppParameterApi, '/installed-apps/<uuid:installed_app_id>/parameters', | |||
| endpoint='installed_app_parameters') | |||
| api.add_resource(ExploreAppMetaApi, '/installed-apps/<uuid:installed_app_id>/meta', endpoint='installed_app_meta') | |||
| api.add_resource( | |||
| AppParameterApi, "/installed-apps/<uuid:installed_app_id>/parameters", endpoint="installed_app_parameters" | |||
| ) | |||
| api.add_resource(ExploreAppMetaApi, "/installed-apps/<uuid:installed_app_id>/meta", endpoint="installed_app_meta") | |||
| @@ -8,28 +8,28 @@ from libs.login import login_required | |||
| from services.recommended_app_service import RecommendedAppService | |||
| app_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| "mode": fields.String, | |||
| "icon": fields.String, | |||
| "icon_background": fields.String, | |||
| } | |||
| recommended_app_fields = { | |||
| 'app': fields.Nested(app_fields, attribute='app'), | |||
| 'app_id': fields.String, | |||
| 'description': fields.String(attribute='description'), | |||
| 'copyright': fields.String, | |||
| 'privacy_policy': fields.String, | |||
| 'custom_disclaimer': fields.String, | |||
| 'category': fields.String, | |||
| 'position': fields.Integer, | |||
| 'is_listed': fields.Boolean | |||
| "app": fields.Nested(app_fields, attribute="app"), | |||
| "app_id": fields.String, | |||
| "description": fields.String(attribute="description"), | |||
| "copyright": fields.String, | |||
| "privacy_policy": fields.String, | |||
| "custom_disclaimer": fields.String, | |||
| "category": fields.String, | |||
| "position": fields.Integer, | |||
| "is_listed": fields.Boolean, | |||
| } | |||
| recommended_app_list_fields = { | |||
| 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), | |||
| 'categories': fields.List(fields.String) | |||
| "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), | |||
| "categories": fields.List(fields.String), | |||
| } | |||
| @@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource): | |||
| def get(self): | |||
| # language args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('language', type=str, location='args') | |||
| parser.add_argument("language", type=str, location="args") | |||
| args = parser.parse_args() | |||
| if args.get('language') and args.get('language') in languages: | |||
| language_prefix = args.get('language') | |||
| if args.get("language") and args.get("language") in languages: | |||
| language_prefix = args.get("language") | |||
| elif current_user and current_user.interface_language: | |||
| language_prefix = current_user.interface_language | |||
| else: | |||
| @@ -61,5 +61,5 @@ class RecommendedAppApi(Resource): | |||
| return RecommendedAppService.get_recommend_app_detail(app_id) | |||
| api.add_resource(RecommendedAppListApi, '/explore/apps') | |||
| api.add_resource(RecommendedAppApi, '/explore/apps/<uuid:app_id>') | |||
| api.add_resource(RecommendedAppListApi, "/explore/apps") | |||
| api.add_resource(RecommendedAppApi, "/explore/apps/<uuid:app_id>") | |||
| @@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value | |||
| from services.errors.message import MessageNotExistsError | |||
| from services.saved_message_service import SavedMessageService | |||
| feedback_fields = { | |||
| 'rating': fields.String | |||
| } | |||
| feedback_fields = {"rating": fields.String} | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'created_at': TimestampField | |||
| "id": fields.String, | |||
| "inputs": fields.Raw, | |||
| "query": fields.String, | |||
| "answer": fields.String, | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "created_at": TimestampField, | |||
| } | |||
| class SavedMessageListApi(InstalledAppResource): | |||
| saved_message_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(message_fields)) | |||
| "limit": fields.Integer, | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(message_fields)), | |||
| } | |||
| @marshal_with(saved_message_infinite_scroll_pagination_fields) | |||
| def get(self, installed_app): | |||
| app_model = installed_app.app | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) | |||
| return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) | |||
| def post(self, installed_app): | |||
| app_model = installed_app.app | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', type=uuid_value, required=True, location='json') | |||
| parser.add_argument("message_id", type=uuid_value, required=True, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| SavedMessageService.save(app_model, current_user, args['message_id']) | |||
| SavedMessageService.save(app_model, current_user, args["message_id"]) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class SavedMessageApi(InstalledAppResource): | |||
| @@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource): | |||
| message_id = str(message_id) | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| SavedMessageService.delete(app_model, current_user, message_id) | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| api.add_resource(SavedMessageListApi, '/installed-apps/<uuid:installed_app_id>/saved-messages', endpoint='installed_app_saved_messages') | |||
| api.add_resource(SavedMessageApi, '/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>', endpoint='installed_app_saved_message') | |||
| api.add_resource( | |||
| SavedMessageListApi, | |||
| "/installed-apps/<uuid:installed_app_id>/saved-messages", | |||
| endpoint="installed_app_saved_messages", | |||
| ) | |||
| api.add_resource( | |||
| SavedMessageApi, | |||
| "/installed-apps/<uuid:installed_app_id>/saved-messages/<uuid:message_id>", | |||
| endpoint="installed_app_saved_message", | |||
| ) | |||
| @@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): | |||
| raise NotWorkflowAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=True | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return { | |||
| "result": "success" | |||
| } | |||
| return {"result": "success"} | |||
| api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps/<uuid:installed_app_id>/workflows/run') | |||
| api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop') | |||
| api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run") | |||
| api.add_resource( | |||
| InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop" | |||
| ) | |||
| @@ -14,29 +14,33 @@ def installed_app_required(view=None): | |||
| def decorator(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if not kwargs.get('installed_app_id'): | |||
| raise ValueError('missing installed_app_id in path parameters') | |||
| if not kwargs.get("installed_app_id"): | |||
| raise ValueError("missing installed_app_id in path parameters") | |||
| installed_app_id = kwargs.get('installed_app_id') | |||
| installed_app_id = kwargs.get("installed_app_id") | |||
| installed_app_id = str(installed_app_id) | |||
| del kwargs['installed_app_id'] | |||
| del kwargs["installed_app_id"] | |||
| installed_app = db.session.query(InstalledApp).filter( | |||
| InstalledApp.id == str(installed_app_id), | |||
| InstalledApp.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| installed_app = ( | |||
| db.session.query(InstalledApp) | |||
| .filter( | |||
| InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| .first() | |||
| ) | |||
| if installed_app is None: | |||
| raise NotFound('Installed app not found') | |||
| raise NotFound("Installed app not found") | |||
| if not installed_app.app: | |||
| db.session.delete(installed_app) | |||
| db.session.commit() | |||
| raise NotFound('Installed app not found') | |||
| raise NotFound("Installed app not found") | |||
| return view(installed_app, *args, **kwargs) | |||
| return decorated | |||
| if view: | |||
| @@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService | |||
| class CodeBasedExtensionAPI(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('module', type=str, required=True, location='args') | |||
| parser.add_argument("module", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| return { | |||
| 'module': args['module'], | |||
| 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) | |||
| } | |||
| return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} | |||
| class APIBasedExtensionAPI(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource): | |||
| @marshal_with(api_based_extension_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument('api_endpoint', type=str, required=True, location='json') | |||
| parser.add_argument('api_key', type=str, required=True, location='json') | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| parser.add_argument("api_endpoint", type=str, required=True, location="json") | |||
| parser.add_argument("api_key", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| extension_data = APIBasedExtension( | |||
| tenant_id=current_user.current_tenant_id, | |||
| name=args['name'], | |||
| api_endpoint=args['api_endpoint'], | |||
| api_key=args['api_key'] | |||
| name=args["name"], | |||
| api_endpoint=args["api_endpoint"], | |||
| api_key=args["api_key"], | |||
| ) | |||
| return APIBasedExtensionService.save(extension_data) | |||
| class APIBasedExtensionDetailAPI(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource): | |||
| extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument('api_endpoint', type=str, required=True, location='json') | |||
| parser.add_argument('api_key', type=str, required=True, location='json') | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| parser.add_argument("api_endpoint", type=str, required=True, location="json") | |||
| parser.add_argument("api_key", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| extension_data_from_db.name = args['name'] | |||
| extension_data_from_db.api_endpoint = args['api_endpoint'] | |||
| extension_data_from_db.name = args["name"] | |||
| extension_data_from_db.api_endpoint = args["api_endpoint"] | |||
| if args['api_key'] != HIDDEN_VALUE: | |||
| extension_data_from_db.api_key = args['api_key'] | |||
| if args["api_key"] != HIDDEN_VALUE: | |||
| extension_data_from_db.api_key = args["api_key"] | |||
| return APIBasedExtensionService.save(extension_data_from_db) | |||
| @@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource): | |||
| APIBasedExtensionService.delete(extension_data_from_db) | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') | |||
| api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") | |||
| api.add_resource(APIBasedExtensionAPI, '/api-based-extension') | |||
| api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>') | |||
| api.add_resource(APIBasedExtensionAPI, "/api-based-extension") | |||
| api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/<uuid:id>") | |||
| @@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record | |||
| class FeatureApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -24,5 +23,5 @@ class SystemFeatureApi(Resource): | |||
| return FeatureService.get_system_features().model_dump() | |||
| api.add_resource(FeatureApi, '/features') | |||
| api.add_resource(SystemFeatureApi, '/system-features') | |||
| api.add_resource(FeatureApi, "/features") | |||
| api.add_resource(SystemFeatureApi, "/system-features") | |||
| @@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted | |||
| class InitValidateAPI(Resource): | |||
| def get(self): | |||
| init_status = get_init_validate_status() | |||
| if init_status: | |||
| return { 'status': 'finished' } | |||
| return {'status': 'not_started' } | |||
| return {"status": "finished"} | |||
| return {"status": "not_started"} | |||
| @only_edition_self_hosted | |||
| def post(self): | |||
| @@ -29,22 +28,23 @@ class InitValidateAPI(Resource): | |||
| raise AlreadySetupError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('password', type=str_len(30), | |||
| required=True, location='json') | |||
| input_password = parser.parse_args()['password'] | |||
| parser.add_argument("password", type=str_len(30), required=True, location="json") | |||
| input_password = parser.parse_args()["password"] | |||
| if input_password != os.environ.get('INIT_PASSWORD'): | |||
| session['is_init_validated'] = False | |||
| if input_password != os.environ.get("INIT_PASSWORD"): | |||
| session["is_init_validated"] = False | |||
| raise InitValidateFailedError() | |||
| session['is_init_validated'] = True | |||
| return {'result': 'success'}, 201 | |||
| session["is_init_validated"] = True | |||
| return {"result": "success"}, 201 | |||
| def get_init_validate_status(): | |||
| if dify_config.EDITION == 'SELF_HOSTED': | |||
| if os.environ.get('INIT_PASSWORD'): | |||
| return session.get('is_init_validated') or DifySetup.query.first() | |||
| if dify_config.EDITION == "SELF_HOSTED": | |||
| if os.environ.get("INIT_PASSWORD"): | |||
| return session.get("is_init_validated") or DifySetup.query.first() | |||
| return True | |||
| api.add_resource(InitValidateAPI, '/init') | |||
| api.add_resource(InitValidateAPI, "/init") | |||
| @@ -4,14 +4,11 @@ from controllers.console import api | |||
| class PingApi(Resource): | |||
| def get(self): | |||
| """ | |||
| For connection health check | |||
| """ | |||
| return { | |||
| "result": "pong" | |||
| } | |||
| return {"result": "pong"} | |||
| api.add_resource(PingApi, '/ping') | |||
| api.add_resource(PingApi, "/ping") | |||
| @@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted | |||
| class SetupApi(Resource): | |||
| def get(self): | |||
| if dify_config.EDITION == 'SELF_HOSTED': | |||
| if dify_config.EDITION == "SELF_HOSTED": | |||
| setup_status = get_setup_status() | |||
| if setup_status: | |||
| return { | |||
| 'step': 'finished', | |||
| 'setup_at': setup_status.setup_at.isoformat() | |||
| } | |||
| return {'step': 'not_started'} | |||
| return {'step': 'finished'} | |||
| return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} | |||
| return {"step": "not_started"} | |||
| return {"step": "finished"} | |||
| @only_edition_self_hosted | |||
| def post(self): | |||
| @@ -38,28 +34,22 @@ class SetupApi(Resource): | |||
| tenant_count = TenantService.get_tenant_count() | |||
| if tenant_count > 0: | |||
| raise AlreadySetupError() | |||
| if not get_init_validate_status(): | |||
| raise NotInitValidateError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('email', type=email, | |||
| required=True, location='json') | |||
| parser.add_argument('name', type=str_len( | |||
| 30), required=True, location='json') | |||
| parser.add_argument('password', type=valid_password, | |||
| required=True, location='json') | |||
| parser.add_argument("email", type=email, required=True, location="json") | |||
| parser.add_argument("name", type=str_len(30), required=True, location="json") | |||
| parser.add_argument("password", type=valid_password, required=True, location="json") | |||
| args = parser.parse_args() | |||
| # setup | |||
| RegisterService.setup( | |||
| email=args['email'], | |||
| name=args['name'], | |||
| password=args['password'], | |||
| ip_address=get_remote_ip(request) | |||
| email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) | |||
| ) | |||
| return {'result': 'success'}, 201 | |||
| return {"result": "success"}, 201 | |||
| def setup_required(view): | |||
| @@ -68,7 +58,7 @@ def setup_required(view): | |||
| # check setup | |||
| if not get_init_validate_status(): | |||
| raise NotInitValidateError() | |||
| elif not get_setup_status(): | |||
| raise NotSetupError() | |||
| @@ -78,9 +68,10 @@ def setup_required(view): | |||
| def get_setup_status(): | |||
| if dify_config.EDITION == 'SELF_HOSTED': | |||
| if dify_config.EDITION == "SELF_HOSTED": | |||
| return DifySetup.query.first() | |||
| else: | |||
| return True | |||
| api.add_resource(SetupApi, '/setup') | |||
| api.add_resource(SetupApi, "/setup") | |||
| @@ -14,19 +14,18 @@ from services.tag_service import TagService | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| raise ValueError('Name must be between 1 to 50 characters.') | |||
| raise ValueError("Name must be between 1 to 50 characters.") | |||
| return name | |||
| class TagListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(tag_fields) | |||
| def get(self): | |||
| tag_type = request.args.get('type', type=str) | |||
| keyword = request.args.get('keyword', default=None, type=str) | |||
| tag_type = request.args.get("type", type=str) | |||
| keyword = request.args.get("keyword", default=None, type=str) | |||
| tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) | |||
| return tags, 200 | |||
| @@ -40,28 +39,21 @@ class TagListApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', nullable=False, required=True, | |||
| help='Name must be between 1 to 50 characters.', | |||
| type=_validate_name) | |||
| parser.add_argument('type', type=str, location='json', | |||
| choices=Tag.TAG_TYPE_LIST, | |||
| nullable=True, | |||
| help='Invalid tag type.') | |||
| parser.add_argument( | |||
| "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name | |||
| ) | |||
| parser.add_argument( | |||
| "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." | |||
| ) | |||
| args = parser.parse_args() | |||
| tag = TagService.save_tags(args) | |||
| response = { | |||
| 'id': tag.id, | |||
| 'name': tag.name, | |||
| 'type': tag.type, | |||
| 'binding_count': 0 | |||
| } | |||
| response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} | |||
| return response, 200 | |||
| class TagUpdateDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', nullable=False, required=True, | |||
| help='Name must be between 1 to 50 characters.', | |||
| type=_validate_name) | |||
| parser.add_argument( | |||
| "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name | |||
| ) | |||
| args = parser.parse_args() | |||
| tag = TagService.update_tags(args, tag_id) | |||
| binding_count = TagService.get_tag_binding_count(tag_id) | |||
| response = { | |||
| 'id': tag.id, | |||
| 'name': tag.name, | |||
| 'type': tag.type, | |||
| 'binding_count': binding_count | |||
| } | |||
| response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} | |||
| return response, 200 | |||
| @@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource): | |||
| class TagBindingCreateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json', | |||
| help='Tag IDs is required.') | |||
| parser.add_argument('target_id', type=str, nullable=False, required=True, location='json', | |||
| help='Target ID is required.') | |||
| parser.add_argument('type', type=str, location='json', | |||
| choices=Tag.TAG_TYPE_LIST, | |||
| nullable=True, | |||
| help='Invalid tag type.') | |||
| parser.add_argument( | |||
| "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." | |||
| ) | |||
| parser.add_argument( | |||
| "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." | |||
| ) | |||
| parser.add_argument( | |||
| "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." | |||
| ) | |||
| args = parser.parse_args() | |||
| TagService.save_tag_binding(args) | |||
| @@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource): | |||
| class TagBindingDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tag_id', type=str, nullable=False, required=True, | |||
| help='Tag ID is required.') | |||
| parser.add_argument('target_id', type=str, nullable=False, required=True, | |||
| help='Target ID is required.') | |||
| parser.add_argument('type', type=str, location='json', | |||
| choices=Tag.TAG_TYPE_LIST, | |||
| nullable=True, | |||
| help='Invalid tag type.') | |||
| parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") | |||
| parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") | |||
| parser.add_argument( | |||
| "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." | |||
| ) | |||
| args = parser.parse_args() | |||
| TagService.delete_tag_binding(args) | |||
| return 200 | |||
| api.add_resource(TagListApi, '/tags') | |||
| api.add_resource(TagUpdateDeleteApi, '/tags/<uuid:tag_id>') | |||
| api.add_resource(TagBindingCreateApi, '/tag-bindings/create') | |||
| api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove') | |||
| api.add_resource(TagListApi, "/tags") | |||
| api.add_resource(TagUpdateDeleteApi, "/tags/<uuid:tag_id>") | |||
| api.add_resource(TagBindingCreateApi, "/tag-bindings/create") | |||
| api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| @@ -11,42 +10,39 @@ from . import api | |||
| class VersionApi(Resource): | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('current_version', type=str, required=True, location='args') | |||
| parser.add_argument("current_version", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| check_update_url = dify_config.CHECK_UPDATE_URL | |||
| result = { | |||
| 'version': dify_config.CURRENT_VERSION, | |||
| 'release_date': '', | |||
| 'release_notes': '', | |||
| 'can_auto_update': False, | |||
| 'features': { | |||
| 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, | |||
| 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED | |||
| } | |||
| "version": dify_config.CURRENT_VERSION, | |||
| "release_date": "", | |||
| "release_notes": "", | |||
| "can_auto_update": False, | |||
| "features": { | |||
| "can_replace_logo": dify_config.CAN_REPLACE_LOGO, | |||
| "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, | |||
| }, | |||
| } | |||
| if not check_update_url: | |||
| return result | |||
| try: | |||
| response = requests.get(check_update_url, { | |||
| 'current_version': args.get('current_version') | |||
| }) | |||
| response = requests.get(check_update_url, {"current_version": args.get("current_version")}) | |||
| except Exception as error: | |||
| logging.warning("Check update version error: {}.".format(str(error))) | |||
| result['version'] = args.get('current_version') | |||
| result["version"] = args.get("current_version") | |||
| return result | |||
| content = json.loads(response.content) | |||
| result['version'] = content['version'] | |||
| result['release_date'] = content['releaseDate'] | |||
| result['release_notes'] = content['releaseNotes'] | |||
| result['can_auto_update'] = content['canAutoUpdate'] | |||
| result["version"] = content["version"] | |||
| result["release_date"] = content["releaseDate"] | |||
| result["release_notes"] = content["releaseNotes"] | |||
| result["can_auto_update"] = content["canAutoUpdate"] | |||
| return result | |||
| api.add_resource(VersionApi, '/version') | |||
| api.add_resource(VersionApi, "/version") | |||
| @@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr | |||
| class AccountInitApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| def post(self): | |||
| account = current_user | |||
| if account.status == 'active': | |||
| if account.status == "active": | |||
| raise AccountAlreadyInitedError() | |||
| parser = reqparse.RequestParser() | |||
| if dify_config.EDITION == 'CLOUD': | |||
| parser.add_argument('invitation_code', type=str, location='json') | |||
| if dify_config.EDITION == "CLOUD": | |||
| parser.add_argument("invitation_code", type=str, location="json") | |||
| parser.add_argument( | |||
| 'interface_language', type=supported_language, required=True, location='json') | |||
| parser.add_argument('timezone', type=timezone, | |||
| required=True, location='json') | |||
| parser.add_argument("interface_language", type=supported_language, required=True, location="json") | |||
| parser.add_argument("timezone", type=timezone, required=True, location="json") | |||
| args = parser.parse_args() | |||
| if dify_config.EDITION == 'CLOUD': | |||
| if not args['invitation_code']: | |||
| raise ValueError('invitation_code is required') | |||
| if dify_config.EDITION == "CLOUD": | |||
| if not args["invitation_code"]: | |||
| raise ValueError("invitation_code is required") | |||
| # check invitation code | |||
| invitation_code = db.session.query(InvitationCode).filter( | |||
| InvitationCode.code == args['invitation_code'], | |||
| InvitationCode.status == 'unused', | |||
| ).first() | |||
| invitation_code = ( | |||
| db.session.query(InvitationCode) | |||
| .filter( | |||
| InvitationCode.code == args["invitation_code"], | |||
| InvitationCode.status == "unused", | |||
| ) | |||
| .first() | |||
| ) | |||
| if not invitation_code: | |||
| raise InvalidInvitationCodeError() | |||
| invitation_code.status = 'used' | |||
| invitation_code.status = "used" | |||
| invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| invitation_code.used_by_tenant_id = account.current_tenant_id | |||
| invitation_code.used_by_account_id = account.id | |||
| account.interface_language = args['interface_language'] | |||
| account.timezone = args['timezone'] | |||
| account.interface_theme = 'light' | |||
| account.status = 'active' | |||
| account.interface_language = args["interface_language"] | |||
| account.timezone = args["timezone"] | |||
| account.interface_theme = "light" | |||
| account.status = "active" | |||
| account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class AccountProfileApi(Resource): | |||
| @@ -90,15 +91,14 @@ class AccountNameApi(Resource): | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| # Validate account name length | |||
| if len(args['name']) < 3 or len(args['name']) > 30: | |||
| raise ValueError( | |||
| "Account name must be between 3 and 30 characters.") | |||
| if len(args["name"]) < 3 or len(args["name"]) > 30: | |||
| raise ValueError("Account name must be between 3 and 30 characters.") | |||
| updated_account = AccountService.update_account(current_user, name=args['name']) | |||
| updated_account = AccountService.update_account(current_user, name=args["name"]) | |||
| return updated_account | |||
| @@ -110,10 +110,10 @@ class AccountAvatarApi(Resource): | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('avatar', type=str, required=True, location='json') | |||
| parser.add_argument("avatar", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| updated_account = AccountService.update_account(current_user, avatar=args['avatar']) | |||
| updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) | |||
| return updated_account | |||
| @@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource): | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| 'interface_language', type=supported_language, required=True, location='json') | |||
| parser.add_argument("interface_language", type=supported_language, required=True, location="json") | |||
| args = parser.parse_args() | |||
| updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) | |||
| updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) | |||
| return updated_account | |||
| @@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource): | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('interface_theme', type=str, choices=[ | |||
| 'light', 'dark'], required=True, location='json') | |||
| parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") | |||
| args = parser.parse_args() | |||
| updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) | |||
| updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) | |||
| return updated_account | |||
| @@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource): | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('timezone', type=str, | |||
| required=True, location='json') | |||
| parser.add_argument("timezone", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| # Validate timezone string, e.g. America/New_York, Asia/Shanghai | |||
| if args['timezone'] not in pytz.all_timezones: | |||
| if args["timezone"] not in pytz.all_timezones: | |||
| raise ValueError("Invalid timezone string.") | |||
| updated_account = AccountService.update_account(current_user, timezone=args['timezone']) | |||
| updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) | |||
| return updated_account | |||
| @@ -177,20 +174,16 @@ class AccountPasswordApi(Resource): | |||
| @marshal_with(account_fields) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('password', type=str, | |||
| required=False, location='json') | |||
| parser.add_argument('new_password', type=str, | |||
| required=True, location='json') | |||
| parser.add_argument('repeat_new_password', type=str, | |||
| required=True, location='json') | |||
| parser.add_argument("password", type=str, required=False, location="json") | |||
| parser.add_argument("new_password", type=str, required=True, location="json") | |||
| parser.add_argument("repeat_new_password", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| if args['new_password'] != args['repeat_new_password']: | |||
| if args["new_password"] != args["repeat_new_password"]: | |||
| raise RepeatPasswordNotMatchError() | |||
| try: | |||
| AccountService.update_account_password( | |||
| current_user, args['password'], args['new_password']) | |||
| AccountService.update_account_password(current_user, args["password"], args["new_password"]) | |||
| except ServiceCurrentPasswordIncorrectError: | |||
| raise CurrentPasswordIncorrectError() | |||
| @@ -199,14 +192,14 @@ class AccountPasswordApi(Resource): | |||
| class AccountIntegrateApi(Resource): | |||
| integrate_fields = { | |||
| 'provider': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'is_bound': fields.Boolean, | |||
| 'link': fields.String | |||
| "provider": fields.String, | |||
| "created_at": TimestampField, | |||
| "is_bound": fields.Boolean, | |||
| "link": fields.String, | |||
| } | |||
| integrate_list_fields = { | |||
| 'data': fields.List(fields.Nested(integrate_fields)), | |||
| "data": fields.List(fields.Nested(integrate_fields)), | |||
| } | |||
| @setup_required | |||
| @@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource): | |||
| def get(self): | |||
| account = current_user | |||
| account_integrates = db.session.query(AccountIntegrate).filter( | |||
| AccountIntegrate.account_id == account.id).all() | |||
| account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() | |||
| base_url = request.url_root.rstrip('/') | |||
| base_url = request.url_root.rstrip("/") | |||
| oauth_base_path = "/console/api/oauth/login" | |||
| providers = ["github", "google"] | |||
| @@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource): | |||
| for provider in providers: | |||
| existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) | |||
| if existing_integrate: | |||
| integrate_data.append({ | |||
| 'id': existing_integrate.id, | |||
| 'provider': provider, | |||
| 'created_at': existing_integrate.created_at, | |||
| 'is_bound': True, | |||
| 'link': None | |||
| }) | |||
| integrate_data.append( | |||
| { | |||
| "id": existing_integrate.id, | |||
| "provider": provider, | |||
| "created_at": existing_integrate.created_at, | |||
| "is_bound": True, | |||
| "link": None, | |||
| } | |||
| ) | |||
| else: | |||
| integrate_data.append({ | |||
| 'id': None, | |||
| 'provider': provider, | |||
| 'created_at': None, | |||
| 'is_bound': False, | |||
| 'link': f'{base_url}{oauth_base_path}/{provider}' | |||
| }) | |||
| return {'data': integrate_data} | |||
| integrate_data.append( | |||
| { | |||
| "id": None, | |||
| "provider": provider, | |||
| "created_at": None, | |||
| "is_bound": False, | |||
| "link": f"{base_url}{oauth_base_path}/{provider}", | |||
| } | |||
| ) | |||
| return {"data": integrate_data} | |||
| # Register API resources | |||
| api.add_resource(AccountInitApi, '/account/init') | |||
| api.add_resource(AccountProfileApi, '/account/profile') | |||
| api.add_resource(AccountNameApi, '/account/name') | |||
| api.add_resource(AccountAvatarApi, '/account/avatar') | |||
| api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language') | |||
| api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme') | |||
| api.add_resource(AccountTimezoneApi, '/account/timezone') | |||
| api.add_resource(AccountPasswordApi, '/account/password') | |||
| api.add_resource(AccountIntegrateApi, '/account/integrates') | |||
| api.add_resource(AccountInitApi, "/account/init") | |||
| api.add_resource(AccountProfileApi, "/account/profile") | |||
| api.add_resource(AccountNameApi, "/account/name") | |||
| api.add_resource(AccountAvatarApi, "/account/avatar") | |||
| api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") | |||
| api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") | |||
| api.add_resource(AccountTimezoneApi, "/account/timezone") | |||
| api.add_resource(AccountPasswordApi, "/account/password") | |||
| api.add_resource(AccountIntegrateApi, "/account/integrates") | |||
| # api.add_resource(AccountEmailApi, '/account/email') | |||
| # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') | |||
| @@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException | |||
| class RepeatPasswordNotMatchError(BaseHTTPException): | |||
| error_code = 'repeat_password_not_match' | |||
| error_code = "repeat_password_not_match" | |||
| description = "New password and repeat password does not match." | |||
| code = 400 | |||
| class CurrentPasswordIncorrectError(BaseHTTPException): | |||
| error_code = 'current_password_incorrect' | |||
| error_code = "current_password_incorrect" | |||
| description = "Current password is incorrect." | |||
| code = 400 | |||
| class ProviderRequestFailedError(BaseHTTPException): | |||
| error_code = 'provider_request_failed' | |||
| error_code = "provider_request_failed" | |||
| description = None | |||
| code = 400 | |||
| class InvalidInvitationCodeError(BaseHTTPException): | |||
| error_code = 'invalid_invitation_code' | |||
| error_code = "invalid_invitation_code" | |||
| description = "Invalid invitation code." | |||
| code = 400 | |||
| class AccountAlreadyInitedError(BaseHTTPException): | |||
| error_code = 'account_already_inited' | |||
| error_code = "account_already_inited" | |||
| description = "The account has been initialized. Please refresh the page." | |||
| code = 400 | |||
| class AccountNotInitializedError(BaseHTTPException): | |||
| error_code = 'account_not_initialized' | |||
| error_code = "account_not_initialized" | |||
| description = "The account has not been initialized yet. Please proceed with the initialization process first." | |||
| code = 400 | |||
| @@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| # validate model load balancing credentials | |||
| @@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource): | |||
| model_load_balancing_service.validate_load_balancing_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| credentials=args['credentials'] | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response['error'] = error | |||
| response["error"] = error | |||
| return response | |||
| @@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| # validate model load balancing config credentials | |||
| @@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): | |||
| model_load_balancing_service.validate_load_balancing_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| credentials=args['credentials'], | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| config_id=config_id, | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response['error'] = error | |||
| response["error"] = error | |||
| return response | |||
| # Load Balancing Config | |||
| api.add_resource(LoadBalancingCredentialsValidateApi, | |||
| '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate') | |||
| api.add_resource(LoadBalancingConfigCredentialsValidateApi, | |||
| '/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate') | |||
| api.add_resource( | |||
| LoadBalancingCredentialsValidateApi, | |||
| "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/credentials-validate", | |||
| ) | |||
| api.add_resource( | |||
| LoadBalancingConfigCredentialsValidateApi, | |||
| "/workspaces/current/model-providers/<string:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate", | |||
| ) | |||
| @@ -23,7 +23,7 @@ class MemberListApi(Resource): | |||
| @marshal_with(account_with_role_list_fields) | |||
| def get(self): | |||
| members = TenantService.get_tenant_members(current_user.current_tenant) | |||
| return {'result': 'success', 'accounts': members}, 200 | |||
| return {"result": "success", "accounts": members}, 200 | |||
| class MemberInviteEmailApi(Resource): | |||
| @@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('members') | |||
| @cloud_edition_billing_resource_check("members") | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('emails', type=str, required=True, location='json', action='append') | |||
| parser.add_argument('role', type=str, required=True, default='admin', location='json') | |||
| parser.add_argument('language', type=str, required=False, location='json') | |||
| parser.add_argument("emails", type=str, required=True, location="json", action="append") | |||
| parser.add_argument("role", type=str, required=True, default="admin", location="json") | |||
| parser.add_argument("language", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| invitee_emails = args['emails'] | |||
| invitee_role = args['role'] | |||
| interface_language = args['language'] | |||
| invitee_emails = args["emails"] | |||
| invitee_role = args["role"] | |||
| interface_language = args["language"] | |||
| if not TenantAccountRole.is_non_owner_role(invitee_role): | |||
| return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 | |||
| return {"code": "invalid-role", "message": "Invalid role"}, 400 | |||
| inviter = current_user | |||
| invitation_results = [] | |||
| console_web_url = dify_config.CONSOLE_WEB_URL | |||
| for invitee_email in invitee_emails: | |||
| try: | |||
| token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) | |||
| invitation_results.append({ | |||
| 'status': 'success', | |||
| 'email': invitee_email, | |||
| 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' | |||
| }) | |||
| token = RegisterService.invite_new_member( | |||
| inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter | |||
| ) | |||
| invitation_results.append( | |||
| { | |||
| "status": "success", | |||
| "email": invitee_email, | |||
| "url": f"{console_web_url}/activate?email={invitee_email}&token={token}", | |||
| } | |||
| ) | |||
| except AccountAlreadyInTenantError: | |||
| invitation_results.append({ | |||
| 'status': 'success', | |||
| 'email': invitee_email, | |||
| 'url': f'{console_web_url}/signin' | |||
| }) | |||
| invitation_results.append( | |||
| {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} | |||
| ) | |||
| break | |||
| except Exception as e: | |||
| invitation_results.append({ | |||
| 'status': 'failed', | |||
| 'email': invitee_email, | |||
| 'message': str(e) | |||
| }) | |||
| invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) | |||
| return { | |||
| 'result': 'success', | |||
| 'invitation_results': invitation_results, | |||
| "result": "success", | |||
| "invitation_results": invitation_results, | |||
| }, 201 | |||
| @@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource): | |||
| try: | |||
| TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) | |||
| except services.errors.account.CannotOperateSelfError as e: | |||
| return {'code': 'cannot-operate-self', 'message': str(e)}, 400 | |||
| return {"code": "cannot-operate-self", "message": str(e)}, 400 | |||
| except services.errors.account.NoPermissionError as e: | |||
| return {'code': 'forbidden', 'message': str(e)}, 403 | |||
| return {"code": "forbidden", "message": str(e)}, 403 | |||
| except services.errors.account.MemberNotInTenantError as e: | |||
| return {'code': 'member-not-found', 'message': str(e)}, 404 | |||
| return {"code": "member-not-found", "message": str(e)}, 404 | |||
| except Exception as e: | |||
| raise ValueError(str(e)) | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class MemberUpdateRoleApi(Resource): | |||
| @@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource): | |||
| @account_initialization_required | |||
| def put(self, member_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('role', type=str, required=True, location='json') | |||
| parser.add_argument("role", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| new_role = args['role'] | |||
| new_role = args["role"] | |||
| if not TenantAccountRole.is_valid_role(new_role): | |||
| return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 | |||
| return {"code": "invalid-role", "message": "Invalid role"}, 400 | |||
| member = db.session.get(Account, str(member_id)) | |||
| if not member: | |||
| @@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource): | |||
| # todo: 403 | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class DatasetOperatorMemberListApi(Resource): | |||
| @@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource): | |||
| @marshal_with(account_with_role_list_fields) | |||
| def get(self): | |||
| members = TenantService.get_dataset_operator_members(current_user.current_tenant) | |||
| return {'result': 'success', 'accounts': members}, 200 | |||
| return {"result": "success", "accounts": members}, 200 | |||
| api.add_resource(MemberListApi, '/workspaces/current/members') | |||
| api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email') | |||
| api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>') | |||
| api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role') | |||
| api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators') | |||
| api.add_resource(MemberListApi, "/workspaces/current/members") | |||
| api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") | |||
| api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>") | |||
| api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role") | |||
| api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") | |||
| @@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService | |||
| class ModelProviderListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -25,21 +24,23 @@ class ModelProviderListApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_type', type=str, required=False, nullable=True, | |||
| choices=[mt.value for mt in ModelType], location='args') | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=False, | |||
| nullable=True, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| provider_list = model_provider_service.get_provider_list( | |||
| tenant_id=tenant_id, | |||
| model_type=args.get('model_type') | |||
| ) | |||
| provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) | |||
| return jsonable_encoder({"data": provider_list}) | |||
| class ModelProviderCredentialApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_provider_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider | |||
| ) | |||
| credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) | |||
| return { | |||
| "credentials": credentials | |||
| } | |||
| return {"credentials": credentials} | |||
| class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| @@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource): | |||
| try: | |||
| model_provider_service.provider_credentials_validate( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| credentials=args['credentials'] | |||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response['error'] = error | |||
| response["error"] = error | |||
| return response | |||
| class ModelProviderApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -103,21 +94,19 @@ class ModelProviderApi(Resource): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_provider_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credentials=args['credentials'] | |||
| tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 201 | |||
| return {"result": "success"}, 201 | |||
| @setup_required | |||
| @login_required | |||
| @@ -127,12 +116,9 @@ class ModelProviderApi(Resource): | |||
| raise Forbidden() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_provider_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider | |||
| ) | |||
| model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class ModelProviderIconApi(Resource): | |||
| @@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource): | |||
| def get(self, provider: str, icon_type: str, lang: str): | |||
| model_provider_service = ModelProviderService() | |||
| icon, mimetype = model_provider_service.get_model_provider_icon( | |||
| provider=provider, | |||
| icon_type=icon_type, | |||
| lang=lang | |||
| provider=provider, icon_type=icon_type, lang=lang | |||
| ) | |||
| return send_file(io.BytesIO(icon), mimetype=mimetype) | |||
| class PreferredProviderTypeUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, | |||
| choices=['system', 'custom'], location='json') | |||
| parser.add_argument( | |||
| "preferred_provider_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=["system", "custom"], | |||
| location="json", | |||
| ) | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.switch_preferred_provider( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| preferred_provider_type=args['preferred_provider_type'] | |||
| tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] | |||
| ) | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| @@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| if provider != 'anthropic': | |||
| raise ValueError(f'provider name {provider} is invalid') | |||
| if provider != "anthropic": | |||
| raise ValueError(f"provider name {provider} is invalid") | |||
| BillingService.is_tenant_owner_or_admin(current_user) | |||
| data = BillingService.get_model_provider_payment_link(provider_name=provider, | |||
| tenant_id=current_user.current_tenant_id, | |||
| account_id=current_user.id, | |||
| prefilled_email=current_user.email) | |||
| data = BillingService.get_model_provider_payment_link( | |||
| provider_name=provider, | |||
| tenant_id=current_user.current_tenant_id, | |||
| account_id=current_user.id, | |||
| prefilled_email=current_user.email, | |||
| ) | |||
| return data | |||
| @@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource): | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| model_provider_service = ModelProviderService() | |||
| result = model_provider_service.free_quota_submit( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider | |||
| ) | |||
| result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) | |||
| return result | |||
| @@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=str, required=False, nullable=True, location='args') | |||
| parser.add_argument("token", type=str, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| result = model_provider_service.free_quota_qualification_verify( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| token=args['token'] | |||
| tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] | |||
| ) | |||
| return result | |||
| api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') | |||
| api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials') | |||
| api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate') | |||
| api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>') | |||
| api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/' | |||
| '<string:icon_type>/<string:lang>') | |||
| api.add_resource(PreferredProviderTypeUpdateApi, | |||
| '/workspaces/current/model-providers/<string:provider>/preferred-provider-type') | |||
| api.add_resource(ModelProviderPaymentCheckoutUrlApi, | |||
| '/workspaces/current/model-providers/<string:provider>/checkout-url') | |||
| api.add_resource(ModelProviderFreeQuotaSubmitApi, | |||
| '/workspaces/current/model-providers/<string:provider>/free-quota-submit') | |||
| api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, | |||
| '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify') | |||
| api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | |||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<string:provider>/credentials") | |||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<string:provider>/credentials/validate") | |||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<string:provider>") | |||
| api.add_resource( | |||
| ModelProviderIconApi, "/workspaces/current/model-providers/<string:provider>/" "<string:icon_type>/<string:lang>" | |||
| ) | |||
| api.add_resource( | |||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<string:provider>/preferred-provider-type" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers/<string:provider>/checkout-url" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers/<string:provider>/free-quota-submit" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderFreeQuotaQualificationVerifyApi, | |||
| "/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify", | |||
| ) | |||
| @@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService | |||
| class DefaultModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='args') | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| default_model_entity = model_provider_service.get_default_model_of_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, model_type=args["model_type"] | |||
| ) | |||
| return jsonable_encoder({ | |||
| "data": default_model_entity | |||
| }) | |||
| return jsonable_encoder({"data": default_model_entity}) | |||
| @setup_required | |||
| @login_required | |||
| @@ -44,40 +46,39 @@ class DefaultModelApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') | |||
| parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| model_settings = args['model_settings'] | |||
| model_settings = args["model_settings"] | |||
| for model_setting in model_settings: | |||
| if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: | |||
| raise ValueError('invalid model type') | |||
| if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: | |||
| raise ValueError("invalid model type") | |||
| if 'provider' not in model_setting: | |||
| if "provider" not in model_setting: | |||
| continue | |||
| if 'model' not in model_setting: | |||
| raise ValueError('invalid model') | |||
| if "model" not in model_setting: | |||
| raise ValueError("invalid model") | |||
| try: | |||
| model_provider_service.update_default_model_of_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=model_setting['model_type'], | |||
| provider=model_setting['provider'], | |||
| model=model_setting['model'] | |||
| model_type=model_setting["model_type"], | |||
| provider=model_setting["provider"], | |||
| model=model_setting["model"], | |||
| ) | |||
| except Exception: | |||
| logging.warning(f"{model_setting['model_type']} save error") | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class ModelProviderModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| models = model_provider_service.get_models_by_provider( | |||
| tenant_id=tenant_id, | |||
| provider=provider | |||
| ) | |||
| models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) | |||
| return jsonable_encoder({ | |||
| "data": models | |||
| }) | |||
| return jsonable_encoder({"data": models}) | |||
| @setup_required | |||
| @login_required | |||
| @@ -104,62 +100,66 @@ class ModelProviderModelApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| if ('load_balancing' in args and args['load_balancing'] and | |||
| 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): | |||
| if 'configs' not in args['load_balancing']: | |||
| raise ValueError('invalid load balancing configs') | |||
| if ( | |||
| "load_balancing" in args | |||
| and args["load_balancing"] | |||
| and "enabled" in args["load_balancing"] | |||
| and args["load_balancing"]["enabled"] | |||
| ): | |||
| if "configs" not in args["load_balancing"]: | |||
| raise ValueError("invalid load balancing configs") | |||
| # save load balancing configs | |||
| model_load_balancing_service.update_load_balancing_configs( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| configs=args['load_balancing']['configs'] | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| configs=args["load_balancing"]["configs"], | |||
| ) | |||
| # enable load balancing | |||
| model_load_balancing_service.enable_model_load_balancing( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| else: | |||
| # disable load balancing | |||
| model_load_balancing_service.disable_model_load_balancing( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| if args.get('config_from', '') != 'predefined-model': | |||
| if args.get("config_from", "") != "predefined-model": | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| credentials=args['credentials'] | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| logging.exception(f"save model credentials error: {ex}") | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @@ -171,24 +171,26 @@ class ModelProviderModelApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| class ModelProviderModelCredentialApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -196,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='args') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="args") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model_type=args['model_type'], | |||
| model=args['model'] | |||
| tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] | |||
| ) | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return { | |||
| "credentials": credentials, | |||
| "load_balancing": { | |||
| "enabled": is_load_balancing_enabled, | |||
| "configs": load_balancing_configs | |||
| } | |||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||
| } | |||
| class ModelProviderModelEnableApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -235,24 +233,26 @@ class ModelProviderModelEnableApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.enable_model( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class ModelProviderModelDisableApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -260,24 +260,26 @@ class ModelProviderModelDisableApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.disable_model( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class ModelProviderModelValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -285,10 +287,16 @@ class ModelProviderModelValidateApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| @@ -300,48 +308,42 @@ class ModelProviderModelValidateApi(Resource): | |||
| model_provider_service.model_credentials_validate( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| credentials=args['credentials'] | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response['error'] = error | |||
| response["error"] = error | |||
| return response | |||
| class ModelProviderModelParameterRuleApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="args") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| parameter_rules = model_provider_service.get_model_parameter_rules( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'] | |||
| tenant_id=tenant_id, provider=provider, model=args["model"] | |||
| ) | |||
| return jsonable_encoder({ | |||
| "data": parameter_rules | |||
| }) | |||
| return jsonable_encoder({"data": parameter_rules}) | |||
| class ModelProviderAvailableModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -349,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| models = model_provider_service.get_models_by_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type | |||
| ) | |||
| return jsonable_encoder({ | |||
| "data": models | |||
| }) | |||
# Route registrations for the model-provider console endpoints.
api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models")
api.add_resource(
    ModelProviderModelEnableApi,
    "/workspaces/current/model-providers/<string:provider>/models/enable",
    endpoint="model-provider-model-enable",
)
api.add_resource(
    ModelProviderModelDisableApi,
    "/workspaces/current/model-providers/<string:provider>/models/disable",
    endpoint="model-provider-model-disable",
)
api.add_resource(
    ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials"
)
api.add_resource(
    ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate"
)
api.add_resource(
    ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules"
)
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")
api.add_resource(DefaultModelApi, "/workspaces/current/default-model")
| @@ -28,10 +28,18 @@ class ToolProviderListApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| req = reqparse.RequestParser() | |||
| req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') | |||
| req.add_argument( | |||
| "type", | |||
| type=str, | |||
| choices=["builtin", "model", "api", "workflow"], | |||
| required=False, | |||
| nullable=True, | |||
| location="args", | |||
| ) | |||
| args = req.parse_args() | |||
| return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) | |||
| return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) | |||
| class ToolBuiltinProviderListToolsApi(Resource): | |||
| @setup_required | |||
| @@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( | |||
| user_id, | |||
| tenant_id, | |||
| provider, | |||
| )) | |||
| return jsonable_encoder( | |||
| BuiltinToolManageService.list_builtin_tool_provider_tools( | |||
| user_id, | |||
| tenant_id, | |||
| provider, | |||
| ) | |||
| ) | |||
| class ToolBuiltinProviderDeleteApi(Resource): | |||
| @setup_required | |||
| @@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource): | |||
| def post(self, provider): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| @@ -63,7 +74,8 @@ class ToolBuiltinProviderDeleteApi(Resource): | |||
| tenant_id, | |||
| provider, | |||
| ) | |||
| class ToolBuiltinProviderUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource): | |||
| def post(self, provider): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| @@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource): | |||
| user_id, | |||
| tenant_id, | |||
| provider, | |||
| args['credentials'], | |||
| args["credentials"], | |||
| ) | |||
| class ToolBuiltinProviderGetCredentialsApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): | |||
| provider, | |||
| ) | |||
| class ToolBuiltinProviderIconApi(Resource): | |||
| @setup_required | |||
| def get(self, provider): | |||
| @@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource): | |||
| icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE | |||
| return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) | |||
| class ToolApiProviderAddApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('schema', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) | |||
| parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("schema", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("provider", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) | |||
| parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.create_api_tool_provider( | |||
| user_id, | |||
| tenant_id, | |||
| args['provider'], | |||
| args['icon'], | |||
| args['credentials'], | |||
| args['schema_type'], | |||
| args['schema'], | |||
| args.get('privacy_policy', ''), | |||
| args.get('custom_disclaimer', ''), | |||
| args.get('labels', []), | |||
| args["provider"], | |||
| args["icon"], | |||
| args["credentials"], | |||
| args["schema_type"], | |||
| args["schema"], | |||
| args.get("privacy_policy", ""), | |||
| args.get("custom_disclaimer", ""), | |||
| args.get("labels", []), | |||
| ) | |||
| class ToolApiProviderGetRemoteSchemaApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('url', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument("url", type=str, required=True, nullable=False, location="args") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.get_api_tool_provider_remote_schema( | |||
| current_user.id, | |||
| current_user.current_tenant_id, | |||
| args['url'], | |||
| args["url"], | |||
| ) | |||
| class ToolApiProviderListToolsApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument("provider", type=str, required=True, nullable=False, location="args") | |||
| args = parser.parse_args() | |||
| return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( | |||
| user_id, | |||
| tenant_id, | |||
| args['provider'], | |||
| )) | |||
| return jsonable_encoder( | |||
| ApiToolManageService.list_api_tool_provider_tools( | |||
| user_id, | |||
| tenant_id, | |||
| args["provider"], | |||
| ) | |||
| ) | |||
| class ToolApiProviderUpdateApi(Resource): | |||
| @setup_required | |||
| @@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('schema', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') | |||
| parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') | |||
| parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("schema", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("provider", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") | |||
| parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") | |||
| parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.update_api_tool_provider( | |||
| user_id, | |||
| tenant_id, | |||
| args['provider'], | |||
| args['original_provider'], | |||
| args['icon'], | |||
| args['credentials'], | |||
| args['schema_type'], | |||
| args['schema'], | |||
| args['privacy_policy'], | |||
| args['custom_disclaimer'], | |||
| args.get('labels', []), | |||
| args["provider"], | |||
| args["original_provider"], | |||
| args["icon"], | |||
| args["credentials"], | |||
| args["schema_type"], | |||
| args["schema"], | |||
| args["privacy_policy"], | |||
| args["custom_disclaimer"], | |||
| args.get("labels", []), | |||
| ) | |||
| class ToolApiProviderDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument("provider", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.delete_api_tool_provider( | |||
| user_id, | |||
| tenant_id, | |||
| args['provider'], | |||
| args["provider"], | |||
| ) | |||
| class ToolApiProviderGetApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument("provider", type=str, required=True, nullable=False, location="args") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.get_api_tool_provider( | |||
| user_id, | |||
| tenant_id, | |||
| args['provider'], | |||
| args["provider"], | |||
| ) | |||
| class ToolBuiltinProviderCredentialsSchemaApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): | |||
| def get(self, provider): | |||
| return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) | |||
| class ToolApiProviderSchemaApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource): | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('schema', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument("schema", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.parser_api_schema( | |||
| schema=args['schema'], | |||
| schema=args["schema"], | |||
| ) | |||
| class ToolApiProviderPreviousTestApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource): | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('schema', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("schema", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| return ApiToolManageService.test_api_tool_preview( | |||
| current_user.current_tenant_id, | |||
| args['provider_name'] if args['provider_name'] else '', | |||
| args['tool_name'], | |||
| args['credentials'], | |||
| args['parameters'], | |||
| args['schema_type'], | |||
| args['schema'], | |||
| args["provider_name"] if args["provider_name"] else "", | |||
| args["tool_name"], | |||
| args["credentials"], | |||
| args["parameters"], | |||
| args["schema_type"], | |||
| args["schema"], | |||
| ) | |||
| class ToolWorkflowProviderCreateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| reqparser = reqparse.RequestParser() | |||
| reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') | |||
| reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') | |||
| reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') | |||
| reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") | |||
| reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") | |||
| reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") | |||
| args = reqparser.parse_args() | |||
| return WorkflowToolManageService.create_workflow_tool( | |||
| user_id, | |||
| tenant_id, | |||
| args['workflow_app_id'], | |||
| args['name'], | |||
| args['label'], | |||
| args['icon'], | |||
| args['description'], | |||
| args['parameters'], | |||
| args['privacy_policy'], | |||
| args.get('labels', []), | |||
| args["workflow_app_id"], | |||
| args["name"], | |||
| args["label"], | |||
| args["icon"], | |||
| args["description"], | |||
| args["parameters"], | |||
| args["privacy_policy"], | |||
| args.get("labels", []), | |||
| ) | |||
| class ToolWorkflowProviderUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| reqparser = reqparse.RequestParser() | |||
| reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') | |||
| reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') | |||
| reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') | |||
| reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') | |||
| reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") | |||
| reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") | |||
| reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") | |||
| reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") | |||
| args = reqparser.parse_args() | |||
| if not args['workflow_tool_id']: | |||
| raise ValueError('incorrect workflow_tool_id') | |||
| if not args["workflow_tool_id"]: | |||
| raise ValueError("incorrect workflow_tool_id") | |||
| return WorkflowToolManageService.update_workflow_tool( | |||
| user_id, | |||
| tenant_id, | |||
| args['workflow_tool_id'], | |||
| args['name'], | |||
| args['label'], | |||
| args['icon'], | |||
| args['description'], | |||
| args['parameters'], | |||
| args['privacy_policy'], | |||
| args.get('labels', []), | |||
| args["workflow_tool_id"], | |||
| args["name"], | |||
| args["label"], | |||
| args["icon"], | |||
| args["description"], | |||
| args["parameters"], | |||
| args["privacy_policy"], | |||
| args.get("labels", []), | |||
| ) | |||
| class ToolWorkflowProviderDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource): | |||
| def post(self): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| reqparser = reqparse.RequestParser() | |||
| reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') | |||
| reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| args = reqparser.parse_args() | |||
| return WorkflowToolManageService.delete_workflow_tool( | |||
| user_id, | |||
| tenant_id, | |||
| args['workflow_tool_id'], | |||
| args["workflow_tool_id"], | |||
| ) | |||
| class ToolWorkflowProviderGetApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') | |||
| parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') | |||
| parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| if args.get('workflow_tool_id'): | |||
| if args.get("workflow_tool_id"): | |||
| tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( | |||
| user_id, | |||
| tenant_id, | |||
| args['workflow_tool_id'], | |||
| args["workflow_tool_id"], | |||
| ) | |||
| elif args.get('workflow_app_id'): | |||
| elif args.get("workflow_app_id"): | |||
| tool = WorkflowToolManageService.get_workflow_tool_by_app_id( | |||
| user_id, | |||
| tenant_id, | |||
| args['workflow_app_id'], | |||
| args["workflow_app_id"], | |||
| ) | |||
| else: | |||
| raise ValueError('incorrect workflow_tool_id or workflow_app_id') | |||
| raise ValueError("incorrect workflow_tool_id or workflow_app_id") | |||
| return jsonable_encoder(tool) | |||
| class ToolWorkflowProviderListToolApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') | |||
| parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") | |||
| args = parser.parse_args() | |||
| return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( | |||
| user_id, | |||
| tenant_id, | |||
| args['workflow_tool_id'], | |||
| )) | |||
| return jsonable_encoder( | |||
| WorkflowToolManageService.list_single_workflow_tools( | |||
| user_id, | |||
| tenant_id, | |||
| args["workflow_tool_id"], | |||
| ) | |||
| ) | |||
| class ToolBuiltinListApi(Resource): | |||
| @setup_required | |||
| @@ -465,11 +498,17 @@ class ToolBuiltinListApi(Resource): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( | |||
| user_id, | |||
| tenant_id, | |||
| )]) | |||
| return jsonable_encoder( | |||
| [ | |||
| provider.to_dict() | |||
| for provider in BuiltinToolManageService.list_builtin_tools( | |||
| user_id, | |||
| tenant_id, | |||
| ) | |||
| ] | |||
| ) | |||
| class ToolApiListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -478,11 +517,17 @@ class ToolApiListApi(Resource): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( | |||
| user_id, | |||
| tenant_id, | |||
| )]) | |||
| return jsonable_encoder( | |||
| [ | |||
| provider.to_dict() | |||
| for provider in ApiToolManageService.list_api_tools( | |||
| user_id, | |||
| tenant_id, | |||
| ) | |||
| ] | |||
| ) | |||
| class ToolWorkflowListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -491,11 +536,17 @@ class ToolWorkflowListApi(Resource): | |||
| user_id = current_user.id | |||
| tenant_id = current_user.current_tenant_id | |||
| return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( | |||
| user_id, | |||
| tenant_id, | |||
| )]) | |||
| return jsonable_encoder( | |||
| [ | |||
| provider.to_dict() | |||
| for provider in WorkflowToolManageService.list_tenant_workflow_tools( | |||
| user_id, | |||
| tenant_id, | |||
| ) | |||
| ] | |||
| ) | |||
| class ToolLabelsApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -503,36 +554,41 @@ class ToolLabelsApi(Resource): | |||
| def get(self): | |||
| return jsonable_encoder(ToolLabelsService.list_tool_labels()) | |||
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")

# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
api.add_resource(
    ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
)
api.add_resource(
    ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")

# api tool provider
api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")
api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote")
api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools")
api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update")
api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete")
api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get")
api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema")
api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre")

# workflow tool provider
api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create")
api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update")
api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete")
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")

api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")

api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
| @@ -26,39 +26,34 @@ from services.file_service import FileService | |||
| from services.workspace_service import WorkspaceService | |||
| provider_fields = { | |||
| 'provider_name': fields.String, | |||
| 'provider_type': fields.String, | |||
| 'is_valid': fields.Boolean, | |||
| 'token_is_set': fields.Boolean, | |||
| "provider_name": fields.String, | |||
| "provider_type": fields.String, | |||
| "is_valid": fields.Boolean, | |||
| "token_is_set": fields.Boolean, | |||
| } | |||
| tenant_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'plan': fields.String, | |||
| 'status': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'role': fields.String, | |||
| 'in_trial': fields.Boolean, | |||
| 'trial_end_reason': fields.String, | |||
| 'custom_config': fields.Raw(attribute='custom_config'), | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| "plan": fields.String, | |||
| "status": fields.String, | |||
| "created_at": TimestampField, | |||
| "role": fields.String, | |||
| "in_trial": fields.Boolean, | |||
| "trial_end_reason": fields.String, | |||
| "custom_config": fields.Raw(attribute="custom_config"), | |||
| } | |||
| tenants_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'plan': fields.String, | |||
| 'status': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'current': fields.Boolean | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| "plan": fields.String, | |||
| "status": fields.String, | |||
| "created_at": TimestampField, | |||
| "current": fields.Boolean, | |||
| } | |||
| workspace_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'status': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} | |||
| class TenantListApi(Resource): | |||
| @@ -71,7 +66,7 @@ class TenantListApi(Resource): | |||
| for tenant in tenants: | |||
| if tenant.id == current_user.current_tenant_id: | |||
| tenant.current = True # Set current=True for current tenant | |||
| return {'workspaces': marshal(tenants, tenants_fields)}, 200 | |||
| return {"workspaces": marshal(tenants, tenants_fields)}, 200 | |||
| class WorkspaceListApi(Resource): | |||
| @@ -79,31 +74,37 @@ class WorkspaceListApi(Resource): | |||
| @admin_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') | |||
| parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") | |||
| parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\ | |||
| .paginate(page=args['page'], per_page=args['limit']) | |||
| tenants = ( | |||
| db.session.query(Tenant) | |||
| .order_by(Tenant.created_at.desc()) | |||
| .paginate(page=args["page"], per_page=args["limit"]) | |||
| ) | |||
| has_more = False | |||
| if len(tenants.items) == args['limit']: | |||
| if len(tenants.items) == args["limit"]: | |||
| current_page_first_tenant = tenants[-1] | |||
| rest_count = db.session.query(Tenant).filter( | |||
| Tenant.created_at < current_page_first_tenant.created_at, | |||
| Tenant.id != current_page_first_tenant.id | |||
| ).count() | |||
| rest_count = ( | |||
| db.session.query(Tenant) | |||
| .filter( | |||
| Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id | |||
| ) | |||
| .count() | |||
| ) | |||
| if rest_count > 0: | |||
| has_more = True | |||
| total = db.session.query(Tenant).count() | |||
| return { | |||
| 'data': marshal(tenants.items, workspace_fields), | |||
| 'has_more': has_more, | |||
| 'limit': args['limit'], | |||
| 'page': args['page'], | |||
| 'total': total | |||
| }, 200 | |||
| "data": marshal(tenants.items, workspace_fields), | |||
| "has_more": has_more, | |||
| "limit": args["limit"], | |||
| "page": args["page"], | |||
| "total": total, | |||
| }, 200 | |||
| class TenantApi(Resource): | |||
| @@ -112,8 +113,8 @@ class TenantApi(Resource): | |||
| @account_initialization_required | |||
| @marshal_with(tenant_fields) | |||
| def get(self): | |||
| if request.path == '/info': | |||
| logging.warning('Deprecated URL /info was used.') | |||
| if request.path == "/info": | |||
| logging.warning("Deprecated URL /info was used.") | |||
| tenant = current_user.current_tenant | |||
| @@ -125,7 +126,7 @@ class TenantApi(Resource): | |||
| tenant = tenants[0] | |||
| # else, raise Unauthorized | |||
| else: | |||
| raise Unauthorized('workspace is archived') | |||
| raise Unauthorized("workspace is archived") | |||
| return WorkspaceService.get_tenant_info(tenant), 200 | |||
| @@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource): | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('tenant_id', type=str, required=True, location='json') | |||
| parser.add_argument("tenant_id", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| # check if tenant_id is valid, 403 if not | |||
| try: | |||
| TenantService.switch_tenant(current_user, args['tenant_id']) | |||
| TenantService.switch_tenant(current_user, args["tenant_id"]) | |||
| except Exception: | |||
| raise AccountNotLinkTenantError("Account not link tenant") | |||
| new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant | |||
| new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant | |||
| return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} | |||
| return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} | |||
| class CustomConfigWorkspaceApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('workspace_custom') | |||
| @cloud_edition_billing_resource_check("workspace_custom") | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('remove_webapp_brand', type=bool, location='json') | |||
| parser.add_argument('replace_webapp_logo', type=str, location='json') | |||
| parser.add_argument("remove_webapp_brand", type=bool, location="json") | |||
| parser.add_argument("replace_webapp_logo", type=str, location="json") | |||
| args = parser.parse_args() | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() | |||
| custom_config_dict = { | |||
| 'remove_webapp_brand': args['remove_webapp_brand'], | |||
| 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , | |||
| "remove_webapp_brand": args["remove_webapp_brand"], | |||
| "replace_webapp_logo": args["replace_webapp_logo"] | |||
| if args["replace_webapp_logo"] is not None | |||
| else tenant.custom_config_dict.get("replace_webapp_logo"), | |||
| } | |||
| tenant.custom_config_dict = custom_config_dict | |||
| db.session.commit() | |||
| return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} | |||
| return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} | |||
| class WebappLogoWorkspaceApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('workspace_custom') | |||
| @cloud_edition_billing_resource_check("workspace_custom") | |||
| def post(self): | |||
| # get file from request | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| extension = file.filename.split('.')[-1] | |||
| if extension.lower() not in ['svg', 'png']: | |||
| extension = file.filename.split(".")[-1] | |||
| if extension.lower() not in ["svg", "png"]: | |||
| raise UnsupportedFileTypeError() | |||
| try: | |||
| @@ -201,14 +204,14 @@ class WebappLogoWorkspaceApi(Resource): | |||
| raise FileTooLargeError(file_too_large_error.description) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| return { 'id': upload_file.id }, 201 | |||
| return {"id": upload_file.id}, 201 | |||
| api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants | |||
| api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants | |||
| api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info | |||
| api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated | |||
| api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant | |||
| api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config') | |||
| api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload') | |||
| api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants | |||
| api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants | |||
| api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info | |||
| api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated | |||
| api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant | |||
| api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") | |||
| api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") | |||
| @@ -16,7 +16,7 @@ def account_initialization_required(view): | |||
| # check account initialization | |||
| account = current_user | |||
| if account.status == 'uninitialized': | |||
| if account.status == "uninitialized": | |||
| raise AccountNotInitializedError() | |||
| return view(*args, **kwargs) | |||
| @@ -27,7 +27,7 @@ def account_initialization_required(view): | |||
| def only_edition_cloud(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if dify_config.EDITION != 'CLOUD': | |||
| if dify_config.EDITION != "CLOUD": | |||
| abort(404) | |||
| return view(*args, **kwargs) | |||
| @@ -38,7 +38,7 @@ def only_edition_cloud(view): | |||
| def only_edition_self_hosted(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| if dify_config.EDITION != 'SELF_HOSTED': | |||
| if dify_config.EDITION != "SELF_HOSTED": | |||
| abort(404) | |||
| return view(*args, **kwargs) | |||
| @@ -46,8 +46,9 @@ def only_edition_self_hosted(view): | |||
| return decorated | |||
| def cloud_edition_billing_resource_check(resource: str, | |||
| error_msg: str = "You have reached the limit of your subscription."): | |||
| def cloud_edition_billing_resource_check( | |||
| resource: str, error_msg: str = "You have reached the limit of your subscription." | |||
| ): | |||
| def interceptor(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| @@ -58,22 +59,22 @@ def cloud_edition_billing_resource_check(resource: str, | |||
| vector_space = features.vector_space | |||
| documents_upload_quota = features.documents_upload_quota | |||
| annotation_quota_limit = features.annotation_quota_limit | |||
| if resource == 'members' and 0 < members.limit <= members.size: | |||
| if resource == "members" and 0 < members.limit <= members.size: | |||
| abort(403, error_msg) | |||
| elif resource == 'apps' and 0 < apps.limit <= apps.size: | |||
| elif resource == "apps" and 0 < apps.limit <= apps.size: | |||
| abort(403, error_msg) | |||
| elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: | |||
| elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: | |||
| abort(403, error_msg) | |||
| elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: | |||
| elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: | |||
| # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets | |||
| source = request.args.get('source') | |||
| if source == 'datasets': | |||
| source = request.args.get("source") | |||
| if source == "datasets": | |||
| abort(403, error_msg) | |||
| else: | |||
| return view(*args, **kwargs) | |||
| elif resource == 'workspace_custom' and not features.can_replace_logo: | |||
| elif resource == "workspace_custom" and not features.can_replace_logo: | |||
| abort(403, error_msg) | |||
| elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: | |||
| elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: | |||
| abort(403, error_msg) | |||
| else: | |||
| return view(*args, **kwargs) | |||
| @@ -85,15 +86,17 @@ def cloud_edition_billing_resource_check(resource: str, | |||
| return interceptor | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str, | |||
| error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): | |||
| def cloud_edition_billing_knowledge_limit_check( | |||
| resource: str, | |||
| error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", | |||
| ): | |||
| def interceptor(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| if resource == 'add_segment': | |||
| if features.billing.subscription.plan == 'sandbox': | |||
| if resource == "add_segment": | |||
| if features.billing.subscription.plan == "sandbox": | |||
| abort(403, error_msg) | |||
| else: | |||
| return view(*args, **kwargs) | |||
| @@ -112,7 +115,7 @@ def cloud_utm_record(view): | |||
| features = FeatureService.get_features(current_user.current_tenant_id) | |||
| if features.billing.enabled: | |||
| utm_info = request.cookies.get('utm_info') | |||
| utm_info = request.cookies.get("utm_info") | |||
| if utm_info: | |||
| utm_info = json.loads(utm_info) | |||
| @@ -2,7 +2,7 @@ from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint('files', __name__) | |||
| bp = Blueprint("files", __name__) | |||
| api = ExternalApi(bp) | |||
| @@ -13,35 +13,30 @@ class ImagePreviewApi(Resource): | |||
| def get(self, file_id): | |||
| file_id = str(file_id) | |||
| timestamp = request.args.get('timestamp') | |||
| nonce = request.args.get('nonce') | |||
| sign = request.args.get('sign') | |||
| timestamp = request.args.get("timestamp") | |||
| nonce = request.args.get("nonce") | |||
| sign = request.args.get("sign") | |||
| if not timestamp or not nonce or not sign: | |||
| return {'content': 'Invalid request.'}, 400 | |||
| return {"content": "Invalid request."}, 400 | |||
| try: | |||
| generator, mimetype = FileService.get_image_preview( | |||
| file_id, | |||
| timestamp, | |||
| nonce, | |||
| sign | |||
| ) | |||
| generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| return Response(generator, mimetype=mimetype) | |||
| class WorkspaceWebappLogoApi(Resource): | |||
| def get(self, workspace_id): | |||
| workspace_id = str(workspace_id) | |||
| custom_config = TenantService.get_custom_config(workspace_id) | |||
| webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None | |||
| webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None | |||
| if not webapp_logo_file_id: | |||
| raise NotFound('webapp logo is not found') | |||
| raise NotFound("webapp logo is not found") | |||
| try: | |||
| generator, mimetype = FileService.get_public_image_preview( | |||
| @@ -53,11 +48,11 @@ class WorkspaceWebappLogoApi(Resource): | |||
| return Response(generator, mimetype=mimetype) | |||
| api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview') | |||
| api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces/<uuid:workspace_id>/webapp-logo') | |||
| api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview") | |||
| api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo") | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| error_code = "unsupported_file_type" | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| @@ -13,36 +13,39 @@ class ToolFilePreviewApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('timestamp', type=str, required=True, location='args') | |||
| parser.add_argument('nonce', type=str, required=True, location='args') | |||
| parser.add_argument('sign', type=str, required=True, location='args') | |||
| parser.add_argument("timestamp", type=str, required=True, location="args") | |||
| parser.add_argument("nonce", type=str, required=True, location="args") | |||
| parser.add_argument("sign", type=str, required=True, location="args") | |||
| args = parser.parse_args() | |||
| if not ToolFileManager.verify_file(file_id=file_id, | |||
| timestamp=args['timestamp'], | |||
| nonce=args['nonce'], | |||
| sign=args['sign'], | |||
| if not ToolFileManager.verify_file( | |||
| file_id=file_id, | |||
| timestamp=args["timestamp"], | |||
| nonce=args["nonce"], | |||
| sign=args["sign"], | |||
| ): | |||
| raise Forbidden('Invalid request.') | |||
| raise Forbidden("Invalid request.") | |||
| try: | |||
| result = ToolFileManager.get_file_generator_by_tool_file_id( | |||
| file_id, | |||
| ) | |||
| if not result: | |||
| raise NotFound('file is not found') | |||
| raise NotFound("file is not found") | |||
| generator, mimetype = result | |||
| except Exception: | |||
| raise UnsupportedFileTypeError() | |||
| return Response(generator, mimetype=mimetype) | |||
| api.add_resource(ToolFilePreviewApi, '/files/tools/<uuid:file_id>.<string:extension>') | |||
| api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>") | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| error_code = "unsupported_file_type" | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| @@ -2,8 +2,7 @@ from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') | |||
| bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") | |||
| api = ExternalApi(bp) | |||
| from .workspace import workspace | |||
| @@ -9,29 +9,24 @@ from services.account_service import TenantService | |||
| class EnterpriseWorkspace(Resource): | |||
| @setup_required | |||
| @inner_api_only | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, location='json') | |||
| parser.add_argument('owner_email', type=str, required=True, location='json') | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| parser.add_argument("owner_email", type=str, required=True, location="json") | |||
| args = parser.parse_args() | |||
| account = Account.query.filter_by(email=args['owner_email']).first() | |||
| account = Account.query.filter_by(email=args["owner_email"]).first() | |||
| if account is None: | |||
| return { | |||
| 'message': 'owner account not found.' | |||
| }, 404 | |||
| return {"message": "owner account not found."}, 404 | |||
| tenant = TenantService.create_tenant(args['name']) | |||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||
| tenant = TenantService.create_tenant(args["name"]) | |||
| TenantService.create_tenant_member(tenant, account, role="owner") | |||
| tenant_was_created.send(tenant) | |||
| return { | |||
| 'message': 'enterprise workspace created.' | |||
| } | |||
| return {"message": "enterprise workspace created."} | |||
| api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') | |||
| api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") | |||
| @@ -17,7 +17,7 @@ def inner_api_only(view): | |||
| abort(404) | |||
| # get header 'X-Inner-Api-Key' | |||
| inner_api_key = request.headers.get('X-Inner-Api-Key') | |||
| inner_api_key = request.headers.get("X-Inner-Api-Key") | |||
| if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: | |||
| abort(401) | |||
| @@ -33,29 +33,29 @@ def inner_api_user_auth(view): | |||
| return view(*args, **kwargs) | |||
| # get header 'X-Inner-Api-Key' | |||
| authorization = request.headers.get('Authorization') | |||
| authorization = request.headers.get("Authorization") | |||
| if not authorization: | |||
| return view(*args, **kwargs) | |||
| parts = authorization.split(':') | |||
| parts = authorization.split(":") | |||
| if len(parts) != 2: | |||
| return view(*args, **kwargs) | |||
| user_id, token = parts | |||
| if ' ' in user_id: | |||
| user_id = user_id.split(' ')[1] | |||
| if " " in user_id: | |||
| user_id = user_id.split(" ")[1] | |||
| inner_api_key = request.headers.get('X-Inner-Api-Key') | |||
| inner_api_key = request.headers.get("X-Inner-Api-Key") | |||
| data_to_sign = f'DIFY {user_id}' | |||
| data_to_sign = f"DIFY {user_id}" | |||
| signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) | |||
| signature = b64encode(signature.digest()).decode('utf-8') | |||
| signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) | |||
| signature = b64encode(signature.digest()).decode("utf-8") | |||
| if signature != token: | |||
| return view(*args, **kwargs) | |||
| kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() | |||
| return view(*args, **kwargs) | |||
| @@ -2,7 +2,7 @@ from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint('service_api', __name__, url_prefix='/v1') | |||
| bp = Blueprint("service_api", __name__, url_prefix="/v1") | |||
| api = ExternalApi(bp) | |||
| @@ -1,4 +1,3 @@ | |||
| from flask_restful import Resource, fields, marshal_with | |||
| from configs import dify_config | |||
| @@ -13,32 +12,30 @@ class AppParameterApi(Resource): | |||
| """Resource for app variables.""" | |||
| variable_fields = { | |||
| 'key': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'type': fields.String, | |||
| 'default': fields.String, | |||
| 'max_length': fields.Integer, | |||
| 'options': fields.List(fields.String) | |||
| "key": fields.String, | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "type": fields.String, | |||
| "default": fields.String, | |||
| "max_length": fields.Integer, | |||
| "options": fields.List(fields.String), | |||
| } | |||
| system_parameters_fields = { | |||
| 'image_file_size_limit': fields.String | |||
| } | |||
| system_parameters_fields = {"image_file_size_limit": fields.String} | |||
| parameters_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'text_to_speech': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'sensitive_word_avoidance': fields.Raw, | |||
| 'file_upload': fields.Raw, | |||
| 'system_parameters': fields.Nested(system_parameters_fields) | |||
| "opening_statement": fields.String, | |||
| "suggested_questions": fields.Raw, | |||
| "suggested_questions_after_answer": fields.Raw, | |||
| "speech_to_text": fields.Raw, | |||
| "text_to_speech": fields.Raw, | |||
| "retriever_resource": fields.Raw, | |||
| "annotation_reply": fields.Raw, | |||
| "more_like_this": fields.Raw, | |||
| "user_input_form": fields.Raw, | |||
| "sensitive_word_avoidance": fields.Raw, | |||
| "file_upload": fields.Raw, | |||
| "system_parameters": fields.Nested(system_parameters_fields), | |||
| } | |||
| @validate_app_token | |||
| @@ -56,30 +53,35 @@ class AppParameterApi(Resource): | |||
| app_model_config = app_model.app_model_config | |||
| features_dict = app_model_config.to_dict() | |||
| user_input_form = features_dict.get('user_input_form', []) | |||
| user_input_form = features_dict.get("user_input_form", []) | |||
| return { | |||
| 'opening_statement': features_dict.get('opening_statement'), | |||
| 'suggested_questions': features_dict.get('suggested_questions', []), | |||
| 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', | |||
| {"enabled": False}), | |||
| 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), | |||
| 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), | |||
| 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), | |||
| 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), | |||
| 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), | |||
| 'user_input_form': user_input_form, | |||
| 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', | |||
| {"enabled": False, "type": "", "configs": []}), | |||
| 'file_upload': features_dict.get('file_upload', {"image": { | |||
| "enabled": False, | |||
| "number_limits": 3, | |||
| "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"] | |||
| }}), | |||
| 'system_parameters': { | |||
| 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT | |||
| } | |||
| "opening_statement": features_dict.get("opening_statement"), | |||
| "suggested_questions": features_dict.get("suggested_questions", []), | |||
| "suggested_questions_after_answer": features_dict.get( | |||
| "suggested_questions_after_answer", {"enabled": False} | |||
| ), | |||
| "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), | |||
| "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), | |||
| "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), | |||
| "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), | |||
| "more_like_this": features_dict.get("more_like_this", {"enabled": False}), | |||
| "user_input_form": user_input_form, | |||
| "sensitive_word_avoidance": features_dict.get( | |||
| "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} | |||
| ), | |||
| "file_upload": features_dict.get( | |||
| "file_upload", | |||
| { | |||
| "image": { | |||
| "enabled": False, | |||
| "number_limits": 3, | |||
| "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"], | |||
| } | |||
| }, | |||
| ), | |||
| "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, | |||
| } | |||
| @@ -89,16 +91,14 @@ class AppMetaApi(Resource): | |||
| """Get app meta""" | |||
| return AppService().get_app_meta(app_model) | |||
| class AppInfoApi(Resource): | |||
| @validate_app_token | |||
| def get(self, app_model: App): | |||
| """Get app information""" | |||
| return { | |||
| 'name':app_model.name, | |||
| 'description':app_model.description | |||
| } | |||
| return {"name": app_model.name, "description": app_model.description} | |||
| api.add_resource(AppParameterApi, '/parameters') | |||
| api.add_resource(AppMetaApi, '/meta') | |||
| api.add_resource(AppInfoApi, '/info') | |||
| api.add_resource(AppParameterApi, "/parameters") | |||
| api.add_resource(AppMetaApi, "/meta") | |||
| api.add_resource(AppInfoApi, "/info") | |||
| @@ -33,14 +33,10 @@ from services.errors.audio import ( | |||
| class AudioApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| try: | |||
| response = AudioService.transcript_asr( | |||
| app_model=app_model, | |||
| file=file, | |||
| end_user=end_user | |||
| ) | |||
| response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| @@ -74,30 +70,32 @@ class TextApi(Resource): | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', type=str, required=False, location='json') | |||
| parser.add_argument('voice', type=str, location='json') | |||
| parser.add_argument('text', type=str, location='json') | |||
| parser.add_argument('streaming', type=bool, location='json') | |||
| parser.add_argument("message_id", type=str, required=False, location="json") | |||
| parser.add_argument("voice", type=str, location="json") | |||
| parser.add_argument("text", type=str, location="json") | |||
| parser.add_argument("streaming", type=bool, location="json") | |||
| args = parser.parse_args() | |||
| message_id = args.get('message_id', None) | |||
| text = args.get('text', None) | |||
| if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict): | |||
| text_to_speech = app_model.workflow.features_dict.get('text_to_speech') | |||
| voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| text_to_speech = app_model.workflow.features_dict.get("text_to_speech") | |||
| voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") | |||
| else: | |||
| try: | |||
| voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') | |||
| voice = ( | |||
| args.get("voice") | |||
| if args.get("voice") | |||
| else app_model.app_model_config.text_to_speech_dict.get("voice") | |||
| ) | |||
| except Exception: | |||
| voice = None | |||
| response = AudioService.transcript_tts( | |||
| app_model=app_model, | |||
| message_id=message_id, | |||
| end_user=end_user.external_user_id, | |||
| voice=voice, | |||
| text=text | |||
| app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text | |||
| ) | |||
| return response | |||
| @@ -127,5 +125,5 @@ class TextApi(Resource): | |||
| raise InternalServerError() | |||
| api.add_resource(AudioApi, '/audio-to-text') | |||
| api.add_resource(TextApi, '/text-to-audio') | |||
| api.add_resource(AudioApi, "/audio-to-text") | |||
| api.add_resource(TextApi, "/text-to-audio") | |||
| @@ -33,21 +33,21 @@ from services.app_generate_service import AppGenerateService | |||
| class CompletionApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise AppUnavailableError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, location="json", default="") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| streaming = args["response_mode"] == "streaming" | |||
| args['auto_generate_name'] = False | |||
| args["auto_generate_name"] = False | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| @@ -84,12 +84,12 @@ class CompletionApi(Resource): | |||
| class CompletionStopApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, task_id): | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise AppUnavailableError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class ChatApi(Resource): | |||
| @@ -100,25 +100,21 @@ class ChatApi(Resource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("conversation_id", type=uuid_value, location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") | |||
| parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| streaming = args["response_mode"] == "streaming" | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| streaming=streaming | |||
| app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -153,10 +149,10 @@ class ChatStopApi(Resource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(CompletionApi, '/completion-messages') | |||
| api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop') | |||
| api.add_resource(ChatApi, '/chat-messages') | |||
| api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop') | |||
| api.add_resource(CompletionApi, "/completion-messages") | |||
| api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop") | |||
| api.add_resource(ChatApi, "/chat-messages") | |||
| api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop") | |||
| @@ -14,7 +14,6 @@ from services.conversation_service import ConversationService | |||
| class ConversationApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| def get(self, app_model: App, end_user: EndUser): | |||
| @@ -23,20 +22,26 @@ class ConversationApi(Resource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], | |||
| required=False, default='-updated_at', location='args') | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument( | |||
| "sort_by", | |||
| type=str, | |||
| choices=["created_at", "-created_at", "updated_at", "-updated_at"], | |||
| required=False, | |||
| default="-updated_at", | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| try: | |||
| return ConversationService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| last_id=args['last_id'], | |||
| limit=args['limit'], | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| sort_by=args['sort_by'] | |||
| sort_by=args["sort_by"], | |||
| ) | |||
| except services.errors.conversation.LastConversationNotExistsError: | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| @@ -56,11 +61,10 @@ class ConversationDetailApi(Resource): | |||
| ConversationService.delete(app_model, conversation_id, end_user) | |||
| except services.errors.conversation.ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class ConversationRenameApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, app_model: App, end_user: EndUser, c_id): | |||
| @@ -71,22 +75,16 @@ class ConversationRenameApi(Resource): | |||
| conversation_id = str(c_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=False, location='json') | |||
| parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') | |||
| parser.add_argument("name", type=str, required=False, location="json") | |||
| parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| return ConversationService.rename( | |||
| app_model, | |||
| conversation_id, | |||
| end_user, | |||
| args['name'], | |||
| args['auto_generate'] | |||
| ) | |||
| return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) | |||
| except services.errors.conversation.ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name') | |||
| api.add_resource(ConversationApi, '/conversations') | |||
| api.add_resource(ConversationDetailApi, '/conversations/<uuid:c_id>', endpoint='conversation_detail') | |||
| api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name") | |||
| api.add_resource(ConversationApi, "/conversations") | |||
| api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail") | |||
| @@ -2,104 +2,108 @@ from libs.exception import BaseHTTPException | |||
| class AppUnavailableError(BaseHTTPException): | |||
| error_code = 'app_unavailable' | |||
| error_code = "app_unavailable" | |||
| description = "App unavailable, please check your app configurations." | |||
| code = 400 | |||
| class NotCompletionAppError(BaseHTTPException): | |||
| error_code = 'not_completion_app' | |||
| error_code = "not_completion_app" | |||
| description = "Please check if your Completion app mode matches the right API route." | |||
| code = 400 | |||
| class NotChatAppError(BaseHTTPException): | |||
| error_code = 'not_chat_app' | |||
| error_code = "not_chat_app" | |||
| description = "Please check if your app mode matches the right API route." | |||
| code = 400 | |||
| class NotWorkflowAppError(BaseHTTPException): | |||
| error_code = 'not_workflow_app' | |||
| error_code = "not_workflow_app" | |||
| description = "Please check if your app mode matches the right API route." | |||
| code = 400 | |||
| class ConversationCompletedError(BaseHTTPException): | |||
| error_code = 'conversation_completed' | |||
| error_code = "conversation_completed" | |||
| description = "The conversation has ended. Please start a new conversation." | |||
| code = 400 | |||
| class ProviderNotInitializeError(BaseHTTPException): | |||
| error_code = 'provider_not_initialize' | |||
| description = "No valid model provider credentials found. " \ | |||
| "Please go to Settings -> Model Provider to complete your provider credentials." | |||
| error_code = "provider_not_initialize" | |||
| description = ( | |||
| "No valid model provider credentials found. " | |||
| "Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| code = 400 | |||
| class ProviderQuotaExceededError(BaseHTTPException): | |||
| error_code = 'provider_quota_exceeded' | |||
| description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ | |||
| "Please go to Settings -> Model Provider to complete your own provider credentials." | |||
| error_code = "provider_quota_exceeded" | |||
| description = ( | |||
| "Your quota for Dify Hosted OpenAI has been exhausted. " | |||
| "Please go to Settings -> Model Provider to complete your own provider credentials." | |||
| ) | |||
| code = 400 | |||
| class ProviderModelCurrentlyNotSupportError(BaseHTTPException): | |||
| error_code = 'model_currently_not_support' | |||
| error_code = "model_currently_not_support" | |||
| description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." | |||
| code = 400 | |||
| class CompletionRequestError(BaseHTTPException): | |||
| error_code = 'completion_request_error' | |||
| error_code = "completion_request_error" | |||
| description = "Completion request failed." | |||
| code = 400 | |||
| class NoAudioUploadedError(BaseHTTPException): | |||
| error_code = 'no_audio_uploaded' | |||
| error_code = "no_audio_uploaded" | |||
| description = "Please upload your audio." | |||
| code = 400 | |||
| class AudioTooLargeError(BaseHTTPException): | |||
| error_code = 'audio_too_large' | |||
| error_code = "audio_too_large" | |||
| description = "Audio size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedAudioTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_audio_type' | |||
| error_code = "unsupported_audio_type" | |||
| description = "Audio type not allowed." | |||
| code = 415 | |||
| class ProviderNotSupportSpeechToTextError(BaseHTTPException): | |||
| error_code = 'provider_not_support_speech_to_text' | |||
| error_code = "provider_not_support_speech_to_text" | |||
| description = "Provider not support speech to text." | |||
| code = 400 | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| error_code = "no_file_uploaded" | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| error_code = "too_many_files" | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| class FileTooLargeError(BaseHTTPException): | |||
| error_code = 'file_too_large' | |||
| error_code = "file_too_large" | |||
| description = "File size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| error_code = "unsupported_file_type" | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| @@ -16,15 +16,13 @@ from services.file_service import FileService | |||
| class FileApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) | |||
| @marshal_with(file_fields) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if not file.mimetype: | |||
| @@ -43,4 +41,4 @@ class FileApi(Resource): | |||
| return upload_file, 201 | |||
| api.add_resource(FileApi, '/files/upload') | |||
| api.add_resource(FileApi, "/files/upload") | |||
| @@ -17,61 +17,59 @@ from services.message_service import MessageService | |||
| class MessageListApi(Resource): | |||
| feedback_fields = { | |||
| 'rating': fields.String | |||
| } | |||
| feedback_fields = {"rating": fields.String} | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| "id": fields.String, | |||
| "message_id": fields.String, | |||
| "position": fields.Integer, | |||
| "dataset_id": fields.String, | |||
| "dataset_name": fields.String, | |||
| "document_id": fields.String, | |||
| "document_name": fields.String, | |||
| "data_source_type": fields.String, | |||
| "segment_id": fields.String, | |||
| "score": fields.Float, | |||
| "hit_count": fields.Integer, | |||
| "word_count": fields.Integer, | |||
| "segment_position": fields.Integer, | |||
| "index_node_hash": fields.String, | |||
| "content": fields.String, | |||
| "created_at": TimestampField, | |||
| } | |||
| agent_thought_fields = { | |||
| 'id': fields.String, | |||
| 'chain_id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'thought': fields.String, | |||
| 'tool': fields.String, | |||
| 'tool_labels': fields.Raw, | |||
| 'tool_input': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'observation': fields.String, | |||
| 'message_files': fields.List(fields.String, attribute='files') | |||
| "id": fields.String, | |||
| "chain_id": fields.String, | |||
| "message_id": fields.String, | |||
| "position": fields.Integer, | |||
| "thought": fields.String, | |||
| "tool": fields.String, | |||
| "tool_labels": fields.Raw, | |||
| "tool_input": fields.String, | |||
| "created_at": TimestampField, | |||
| "observation": fields.String, | |||
| "message_files": fields.List(fields.String, attribute="files"), | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'answer': fields.String(attribute='re_sign_file_url_answer'), | |||
| 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField, | |||
| 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), | |||
| 'status': fields.String, | |||
| 'error': fields.String, | |||
| "id": fields.String, | |||
| "conversation_id": fields.String, | |||
| "inputs": fields.Raw, | |||
| "query": fields.String, | |||
| "answer": fields.String(attribute="re_sign_file_url_answer"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), | |||
| "created_at": TimestampField, | |||
| "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), | |||
| "status": fields.String, | |||
| "error": fields.String, | |||
| } | |||
| message_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(message_fields)) | |||
| "limit": fields.Integer, | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(message_fields)), | |||
| } | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @@ -82,14 +80,15 @@ class MessageListApi(Resource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') | |||
| parser.add_argument('first_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") | |||
| parser.add_argument("first_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| return MessageService.pagination_by_first_id(app_model, end_user, | |||
| args['conversation_id'], args['first_id'], args['limit']) | |||
| return MessageService.pagination_by_first_id( | |||
| app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] | |||
| ) | |||
| except services.errors.conversation.ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| except services.errors.message.FirstMessageNotExistsError: | |||
| @@ -102,15 +101,15 @@ class MessageFeedbackApi(Resource): | |||
| message_id = str(message_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') | |||
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| MessageService.create_feedback(app_model, message_id, end_user, args['rating']) | |||
| MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) | |||
| except services.errors.message.MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class MessageSuggestedApi(Resource): | |||
| @@ -123,10 +122,7 @@ class MessageSuggestedApi(Resource): | |||
| try: | |||
| questions = MessageService.get_suggested_questions_after_answer( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.SERVICE_API | |||
| app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API | |||
| ) | |||
| except services.errors.message.MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -136,9 +132,9 @@ class MessageSuggestedApi(Resource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {'result': 'success', 'data': questions} | |||
| return {"result": "success", "data": questions} | |||
| api.add_resource(MessageListApi, '/messages') | |||
| api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks') | |||
| api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested') | |||
| api.add_resource(MessageListApi, "/messages") | |||
| api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks") | |||
| api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested") | |||
| @@ -30,19 +30,20 @@ from services.app_generate_service import AppGenerateService | |||
| logger = logging.getLogger(__name__) | |||
| workflow_run_fields = { | |||
| 'id': fields.String, | |||
| 'workflow_id': fields.String, | |||
| 'status': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'outputs': fields.Raw, | |||
| 'error': fields.String, | |||
| 'total_steps': fields.Integer, | |||
| 'total_tokens': fields.Integer, | |||
| 'created_at': fields.DateTime, | |||
| 'finished_at': fields.DateTime, | |||
| 'elapsed_time': fields.Float, | |||
| "id": fields.String, | |||
| "workflow_id": fields.String, | |||
| "status": fields.String, | |||
| "inputs": fields.Raw, | |||
| "outputs": fields.Raw, | |||
| "error": fields.String, | |||
| "total_steps": fields.Integer, | |||
| "total_tokens": fields.Integer, | |||
| "created_at": fields.DateTime, | |||
| "finished_at": fields.DateTime, | |||
| "elapsed_time": fields.Float, | |||
| } | |||
| class WorkflowRunDetailApi(Resource): | |||
| @validate_app_token | |||
| @marshal_with(workflow_run_fields) | |||
| @@ -56,6 +57,8 @@ class WorkflowRunDetailApi(Resource): | |||
| workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() | |||
| return workflow_run | |||
| class WorkflowRunApi(Resource): | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| @@ -67,20 +70,16 @@ class WorkflowRunApi(Resource): | |||
| raise NotWorkflowAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| args = parser.parse_args() | |||
| streaming = args.get('response_mode') == 'streaming' | |||
| streaming = args.get("response_mode") == "streaming" | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| streaming=streaming | |||
| app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -111,11 +110,9 @@ class WorkflowTaskStopApi(Resource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| return { | |||
| "result": "success" | |||
| } | |||
| return {"result": "success"} | |||
| api.add_resource(WorkflowRunApi, '/workflows/run') | |||
| api.add_resource(WorkflowRunDetailApi, '/workflows/run/<string:workflow_id>') | |||
| api.add_resource(WorkflowTaskStopApi, '/workflows/tasks/<string:task_id>/stop') | |||
| api.add_resource(WorkflowRunApi, "/workflows/run") | |||
| api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_id>") | |||
| api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop") | |||
| @@ -16,7 +16,7 @@ from services.dataset_service import DatasetService | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| raise ValueError('Name must be between 1 to 40 characters.') | |||
| raise ValueError("Name must be between 1 to 40 characters.") | |||
| return name | |||
| @@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource): | |||
| def get(self, tenant_id): | |||
| """Resource for getting datasets.""" | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| provider = request.args.get('provider', default="vendor") | |||
| search = request.args.get('keyword', default=None, type=str) | |||
| tag_ids = request.args.getlist('tag_ids') | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| provider = request.args.get("provider", default="vendor") | |||
| search = request.args.get("keyword", default=None, type=str) | |||
| tag_ids = request.args.getlist("tag_ids") | |||
| datasets, total = DatasetService.get_datasets(page, limit, provider, | |||
| tenant_id, current_user, search, tag_ids) | |||
| datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids) | |||
| # check embedding setting | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) | |||
| embedding_models = configurations.get_models( | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| only_active=True | |||
| ) | |||
| embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) | |||
| model_names = [] | |||
| for embedding_model in embedding_models: | |||
| @@ -51,50 +45,59 @@ class DatasetListApi(DatasetApiResource): | |||
| data = marshal(datasets, dataset_detail_fields) | |||
| for item in data: | |||
| if item['indexing_technique'] == 'high_quality': | |||
| if item["indexing_technique"] == "high_quality": | |||
| item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | |||
| if item_model in model_names: | |||
| item['embedding_available'] = True | |||
| item["embedding_available"] = True | |||
| else: | |||
| item['embedding_available'] = False | |||
| item["embedding_available"] = False | |||
| else: | |||
| item['embedding_available'] = True | |||
| response = { | |||
| 'data': data, | |||
| 'has_more': len(datasets) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| } | |||
| item["embedding_available"] = True | |||
| response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} | |||
| return response, 200 | |||
| def post(self, tenant_id): | |||
| """Resource for creating datasets.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', nullable=False, required=True, | |||
| help='type is required. Name must be between 1 to 40 characters.', | |||
| type=_validate_name) | |||
| parser.add_argument('indexing_technique', type=str, location='json', | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| help='Invalid indexing technique.') | |||
| parser.add_argument('permission', type=str, location='json', choices=( | |||
| DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False) | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| parser.add_argument( | |||
| "permission", | |||
| type=str, | |||
| location="json", | |||
| choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | |||
| help="Invalid permission.", | |||
| required=False, | |||
| nullable=False, | |||
| ) | |||
| args = parser.parse_args() | |||
| try: | |||
| dataset = DatasetService.create_empty_dataset( | |||
| tenant_id=tenant_id, | |||
| name=args['name'], | |||
| indexing_technique=args['indexing_technique'], | |||
| name=args["name"], | |||
| indexing_technique=args["indexing_technique"], | |||
| account=current_user, | |||
| permission=args['permission'] | |||
| permission=args["permission"], | |||
| ) | |||
| except services.errors.dataset.DatasetNameDuplicateError: | |||
| raise DatasetNameDuplicateError() | |||
| return marshal(dataset, dataset_detail_fields), 200 | |||
| class DatasetApi(DatasetApiResource): | |||
| """Resource for dataset.""" | |||
| @@ -106,7 +109,7 @@ class DatasetApi(DatasetApiResource): | |||
| dataset_id (UUID): The ID of the dataset to be deleted. | |||
| Returns: | |||
| dict: A dictionary with a key 'result' and a value 'success' | |||
| dict: A dictionary with a key 'result' and a value 'success' | |||
| if the dataset was successfully deleted. Omitted in HTTP response. | |||
| int: HTTP status code 204 indicating that the operation was successful. | |||
| @@ -118,11 +121,12 @@ class DatasetApi(DatasetApiResource): | |||
| try: | |||
| if DatasetService.delete_dataset(dataset_id_str, current_user): | |||
| return {'result': 'success'}, 204 | |||
| return {"result": "success"}, 204 | |||
| else: | |||
| raise NotFound("Dataset not found.") | |||
| except services.errors.dataset.DatasetInUseError: | |||
| raise DatasetInUseError() | |||
| api.add_resource(DatasetListApi, '/datasets') | |||
| api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') | |||
| api.add_resource(DatasetListApi, "/datasets") | |||
| api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | |||
| @@ -27,47 +27,40 @@ from services.file_service import FileService | |||
| class DocumentAddByTextApi(DatasetApiResource): | |||
| """Resource for documents.""" | |||
| @cloud_edition_billing_resource_check('vector_space', 'dataset') | |||
| @cloud_edition_billing_resource_check('documents', 'dataset') | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_resource_check("documents", "dataset") | |||
| def post(self, tenant_id, dataset_id): | |||
| """Create document by text.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('text', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument('original_document_id', type=str, required=False, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, | |||
| location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument("name", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("text", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("original_document_id", type=str, required=False, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument( | |||
| "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| raise ValueError("Dataset is not exist.") | |||
| if not dataset.indexing_technique and not args['indexing_technique']: | |||
| raise ValueError('indexing_technique is required.') | |||
| if not dataset.indexing_technique and not args["indexing_technique"]: | |||
| raise ValueError("indexing_technique is required.") | |||
| upload_file = FileService.upload_text(args.get('text'), args.get('name')) | |||
| upload_file = FileService.upload_text(args.get("text"), args.get("name")) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'data_source_type': 'upload_file', | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| "type": "upload_file", | |||
| "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, | |||
| } | |||
| args['data_source'] = data_source | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| @@ -76,60 +69,49 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=current_user, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} | |||
| return documents_and_batch_fields, 200 | |||
| class DocumentUpdateByTextApi(DatasetApiResource): | |||
| """Resource for update documents.""" | |||
| @cloud_edition_billing_resource_check('vector_space', 'dataset') | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Update document by text.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('text', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument("name", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("text", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| raise ValueError("Dataset is not exist.") | |||
| if args['text']: | |||
| upload_file = FileService.upload_text(args.get('text'), args.get('name')) | |||
| if args["text"]: | |||
| upload_file = FileService.upload_text(args.get("text"), args.get("name")) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'data_source_type': 'upload_file', | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| "type": "upload_file", | |||
| "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, | |||
| } | |||
| args['data_source'] = data_source | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| args['original_document_id'] = str(document_id) | |||
| args["original_document_id"] = str(document_id) | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| @@ -137,65 +119,53 @@ class DocumentUpdateByTextApi(DatasetApiResource): | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=current_user, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} | |||
| return documents_and_batch_fields, 200 | |||
| class DocumentAddByFileApi(DatasetApiResource): | |||
| """Resource for documents.""" | |||
| @cloud_edition_billing_resource_check('vector_space', 'dataset') | |||
| @cloud_edition_billing_resource_check('documents', 'dataset') | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_resource_check("documents", "dataset") | |||
| def post(self, tenant_id, dataset_id): | |||
| """Create document by upload file.""" | |||
| args = {} | |||
| if 'data' in request.form: | |||
| args = json.loads(request.form['data']) | |||
| if 'doc_form' not in args: | |||
| args['doc_form'] = 'text_model' | |||
| if 'doc_language' not in args: | |||
| args['doc_language'] = 'English' | |||
| if "data" in request.form: | |||
| args = json.loads(request.form["data"]) | |||
| if "doc_form" not in args: | |||
| args["doc_form"] = "text_model" | |||
| if "doc_language" not in args: | |||
| args["doc_language"] = "English" | |||
| # get dataset info | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| if not dataset.indexing_technique and not args.get('indexing_technique'): | |||
| raise ValueError('indexing_technique is required.') | |||
| raise ValueError("Dataset is not exist.") | |||
| if not dataset.indexing_technique and not args.get("indexing_technique"): | |||
| raise ValueError("indexing_technique is required.") | |||
| # save file info | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| upload_file = FileService.upload_file(file, current_user) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| } | |||
| args['data_source'] = data_source | |||
| data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| @@ -204,63 +174,49 @@ class DocumentAddByFileApi(DatasetApiResource): | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=dataset.created_by_account, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} | |||
| return documents_and_batch_fields, 200 | |||
| class DocumentUpdateByFileApi(DatasetApiResource): | |||
| """Resource for update documents.""" | |||
| @cloud_edition_billing_resource_check('vector_space', 'dataset') | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Update document by upload file.""" | |||
| args = {} | |||
| if 'data' in request.form: | |||
| args = json.loads(request.form['data']) | |||
| if 'doc_form' not in args: | |||
| args['doc_form'] = 'text_model' | |||
| if 'doc_language' not in args: | |||
| args['doc_language'] = 'English' | |||
| if "data" in request.form: | |||
| args = json.loads(request.form["data"]) | |||
| if "doc_form" not in args: | |||
| args["doc_form"] = "text_model" | |||
| if "doc_language" not in args: | |||
| args["doc_language"] = "English" | |||
| # get dataset info | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| if 'file' in request.files: | |||
| raise ValueError("Dataset is not exist.") | |||
| if "file" in request.files: | |||
| # save file info | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| upload_file = FileService.upload_file(file, current_user) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| } | |||
| args['data_source'] = data_source | |||
| data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} | |||
| args["data_source"] = data_source | |||
| # validate args | |||
| args['original_document_id'] = str(document_id) | |||
| args["original_document_id"] = str(document_id) | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| @@ -268,16 +224,13 @@ class DocumentUpdateByFileApi(DatasetApiResource): | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=dataset.created_by_account, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, | |||
| created_from="api", | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} | |||
| return documents_and_batch_fields, 200 | |||
| @@ -289,13 +242,10 @@ class DocumentDeleteApi(DatasetApiResource): | |||
| tenant_id = str(tenant_id) | |||
| # get dataset info | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| raise ValueError("Dataset is not exist.") | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| @@ -311,44 +261,39 @@ class DocumentDeleteApi(DatasetApiResource): | |||
| # delete document | |||
| DocumentService.delete_document(document) | |||
| except services.errors.document.DocumentIndexingError: | |||
| raise DocumentIndexingError('Cannot delete document during indexing.') | |||
| raise DocumentIndexingError("Cannot delete document during indexing.") | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class DocumentListApi(DatasetApiResource): | |||
| def get(self, tenant_id, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| search = request.args.get('keyword', default=None, type=str) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| search = request.args.get("keyword", default=None, type=str) | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| query = Document.query.filter_by( | |||
| dataset_id=str(dataset_id), tenant_id=tenant_id) | |||
| query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) | |||
| if search: | |||
| search = f'%{search}%' | |||
| search = f"%{search}%" | |||
| query = query.filter(Document.name.like(search)) | |||
| query = query.order_by(desc(Document.created_at)) | |||
| paginated_documents = query.paginate( | |||
| page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| documents = paginated_documents.items | |||
| response = { | |||
| 'data': marshal(documents, document_fields), | |||
| 'has_more': len(documents) == limit, | |||
| 'limit': limit, | |||
| 'total': paginated_documents.total, | |||
| 'page': page | |||
| "data": marshal(documents, document_fields), | |||
| "has_more": len(documents) == limit, | |||
| "limit": limit, | |||
| "total": paginated_documents.total, | |||
| "page": page, | |||
| } | |||
| return response | |||
| @@ -360,38 +305,36 @@ class DocumentIndexingStatusApi(DatasetApiResource): | |||
| batch = str(batch) | |||
| tenant_id = str(tenant_id) | |||
| # get dataset | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # get documents | |||
| documents = DocumentService.get_batch_documents(dataset_id, batch) | |||
| if not documents: | |||
| raise NotFound('Documents not found.') | |||
| raise NotFound("Documents not found.") | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != 're_segment').count() | |||
| total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != 're_segment').count() | |||
| completed_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != "re_segment", | |||
| ).count() | |||
| total_segments = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" | |||
| ).count() | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| if document.is_paused: | |||
| document.indexing_status = 'paused' | |||
| document.indexing_status = "paused" | |||
| documents_status.append(marshal(document, document_status_fields)) | |||
| data = { | |||
| 'data': documents_status | |||
| } | |||
| data = {"data": documents_status} | |||
| return data | |||
| api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text') | |||
| api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file') | |||
| api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text') | |||
| api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file') | |||
| api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>') | |||
| api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents') | |||
| api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status') | |||
| api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text") | |||
| api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file") | |||
| api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text") | |||
| api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file") | |||
| api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") | |||
| api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents") | |||
| api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status") | |||
| @@ -2,78 +2,78 @@ from libs.exception import BaseHTTPException | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| error_code = "no_file_uploaded" | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| error_code = "too_many_files" | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| class FileTooLargeError(BaseHTTPException): | |||
| error_code = 'file_too_large' | |||
| error_code = "file_too_large" | |||
| description = "File size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| error_code = "unsupported_file_type" | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| class HighQualityDatasetOnlyError(BaseHTTPException): | |||
| error_code = 'high_quality_dataset_only' | |||
| error_code = "high_quality_dataset_only" | |||
| description = "Current operation only supports 'high-quality' datasets." | |||
| code = 400 | |||
| class DatasetNotInitializedError(BaseHTTPException): | |||
| error_code = 'dataset_not_initialized' | |||
| error_code = "dataset_not_initialized" | |||
| description = "The dataset is still being initialized or indexing. Please wait a moment." | |||
| code = 400 | |||
| class ArchivedDocumentImmutableError(BaseHTTPException): | |||
| error_code = 'archived_document_immutable' | |||
| error_code = "archived_document_immutable" | |||
| description = "The archived document is not editable." | |||
| code = 403 | |||
| class DatasetNameDuplicateError(BaseHTTPException): | |||
| error_code = 'dataset_name_duplicate' | |||
| error_code = "dataset_name_duplicate" | |||
| description = "The dataset name already exists. Please modify your dataset name." | |||
| code = 409 | |||
| class InvalidActionError(BaseHTTPException): | |||
| error_code = 'invalid_action' | |||
| error_code = "invalid_action" | |||
| description = "Invalid action." | |||
| code = 400 | |||
| class DocumentAlreadyFinishedError(BaseHTTPException): | |||
| error_code = 'document_already_finished' | |||
| error_code = "document_already_finished" | |||
| description = "The document has been processed. Please refresh the page or go to the document details." | |||
| code = 400 | |||
| class DocumentIndexingError(BaseHTTPException): | |||
| error_code = 'document_indexing' | |||
| error_code = "document_indexing" | |||
| description = "The document is being processed and cannot be edited." | |||
| code = 400 | |||
| class InvalidMetadataError(BaseHTTPException): | |||
| error_code = 'invalid_metadata' | |||
| error_code = "invalid_metadata" | |||
| description = "The metadata content is incorrect. Please check and verify." | |||
| code = 400 | |||
| class DatasetInUseError(BaseHTTPException): | |||
| error_code = 'dataset_in_use' | |||
| error_code = "dataset_in_use" | |||
| description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." | |||
| code = 409 | |||
| @@ -21,52 +21,47 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer | |||
| class SegmentApi(DatasetApiResource): | |||
| """Resource for segments.""" | |||
| @cloud_edition_billing_resource_check('vector_space', 'dataset') | |||
| @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Create single segment.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.indexing_technique == "high_quality": | |||
| try: | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('segments', type=list, required=False, nullable=True, location='json') | |||
| parser.add_argument("segments", type=list, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| if args['segments'] is not None: | |||
| for args_item in args['segments']: | |||
| if args["segments"] is not None: | |||
| for args_item in args["segments"]: | |||
| SegmentService.segment_create_args_validate(args_item, document) | |||
| segments = SegmentService.multi_create_segment(args['segments'], document, dataset) | |||
| return { | |||
| 'data': marshal(segments, segment_fields), | |||
| 'doc_form': document.doc_form | |||
| }, 200 | |||
| segments = SegmentService.multi_create_segment(args["segments"], document, dataset) | |||
| return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 | |||
| else: | |||
| return {"error": "Segemtns is required"}, 400 | |||
| @@ -75,61 +70,53 @@ class SegmentApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.indexing_technique == "high_quality": | |||
| try: | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('status', type=str, | |||
| action='append', default=[], location='args') | |||
| parser.add_argument('keyword', type=str, default=None, location='args') | |||
| parser.add_argument("status", type=str, action="append", default=[], location="args") | |||
| parser.add_argument("keyword", type=str, default=None, location="args") | |||
| args = parser.parse_args() | |||
| status_list = args['status'] | |||
| keyword = args['keyword'] | |||
| status_list = args["status"] | |||
| keyword = args["keyword"] | |||
| query = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| if status_list: | |||
| query = query.filter(DocumentSegment.status.in_(status_list)) | |||
| if keyword: | |||
| query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) | |||
| query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) | |||
| total = query.count() | |||
| segments = query.order_by(DocumentSegment.position).all() | |||
| return { | |||
| 'data': marshal(segments, segment_fields), | |||
| 'doc_form': document.doc_form, | |||
| 'total': total | |||
| }, 200 | |||
| return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200 | |||
| class DatasetSegmentApi(DatasetApiResource): | |||
| @@ -137,48 +124,41 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound('Segment not found.') | |||
| raise NotFound("Segment not found.") | |||
| SegmentService.delete_segment(segment, document, dataset) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| @cloud_edition_billing_resource_check('vector_space', 'dataset') | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| if dataset.indexing_technique == 'high_quality': | |||
| raise NotFound("Document not found.") | |||
| if dataset.indexing_technique == "high_quality": | |||
| # check embedding model setting | |||
| try: | |||
| model_manager = ModelManager() | |||
| @@ -186,35 +166,34 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider.") | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), | |||
| DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound('Segment not found.') | |||
| raise NotFound("Segment not found.") | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| SegmentService.segment_create_args_validate(args['segment'], document) | |||
| segment = SegmentService.update_segment(args['segment'], segment, document, dataset) | |||
| return { | |||
| 'data': marshal(segment, segment_fields), | |||
| 'doc_form': document.doc_form | |||
| }, 200 | |||
| SegmentService.segment_create_args_validate(args["segment"], document) | |||
| segment = SegmentService.update_segment(args["segment"], segment, document, dataset) | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments') | |||
| api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>') | |||
| api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") | |||
| api.add_resource( | |||
| DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>" | |||
| ) | |||
| @@ -13,4 +13,4 @@ class IndexApi(Resource): | |||
| } | |||
| api.add_resource(IndexApi, '/') | |||
| api.add_resource(IndexApi, "/") | |||
| @@ -21,9 +21,10 @@ class WhereisUserArg(Enum): | |||
| """ | |||
| Enum for whereis_user_arg. | |||
| """ | |||
| QUERY = 'query' | |||
| JSON = 'json' | |||
| FORM = 'form' | |||
| QUERY = "query" | |||
| JSON = "json" | |||
| FORM = "form" | |||
| class FetchUserArg(BaseModel): | |||
| @@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| def decorator(view_func): | |||
| @wraps(view_func) | |||
| def decorated_view(*args, **kwargs): | |||
| api_token = validate_and_get_api_token('app') | |||
| api_token = validate_and_get_api_token("app") | |||
| app_model = db.session.query(App).filter(App.id == api_token.app_id).first() | |||
| if not app_model: | |||
| raise Forbidden("The app no longer exists.") | |||
| if app_model.status != 'normal': | |||
| if app_model.status != "normal": | |||
| raise Forbidden("The app's status is abnormal.") | |||
| if not app_model.enable_api: | |||
| @@ -51,15 +52,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if tenant.status == TenantStatus.ARCHIVE: | |||
| raise Forbidden("The workspace's status is archived.") | |||
| kwargs['app_model'] = app_model | |||
| kwargs["app_model"] = app_model | |||
| if fetch_user_arg: | |||
| if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: | |||
| user_id = request.args.get('user') | |||
| user_id = request.args.get("user") | |||
| elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: | |||
| user_id = request.get_json().get('user') | |||
| user_id = request.get_json().get("user") | |||
| elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: | |||
| user_id = request.form.get('user') | |||
| user_id = request.form.get("user") | |||
| else: | |||
| # use default-user | |||
| user_id = None | |||
| @@ -70,9 +71,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| if user_id: | |||
| user_id = str(user_id) | |||
| kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) | |||
| kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) | |||
| return view_func(*args, **kwargs) | |||
| return decorated_view | |||
| if view is None: | |||
| @@ -81,9 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio | |||
| return decorator(view) | |||
| def cloud_edition_billing_resource_check(resource: str, | |||
| api_token_type: str, | |||
| error_msg: str = "You have reached the limit of your subscription."): | |||
| def cloud_edition_billing_resource_check( | |||
| resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription." | |||
| ): | |||
| def interceptor(view): | |||
| def decorated(*args, **kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| @@ -95,33 +97,37 @@ def cloud_edition_billing_resource_check(resource: str, | |||
| vector_space = features.vector_space | |||
| documents_upload_quota = features.documents_upload_quota | |||
| if resource == 'members' and 0 < members.limit <= members.size: | |||
| if resource == "members" and 0 < members.limit <= members.size: | |||
| raise Forbidden(error_msg) | |||
| elif resource == 'apps' and 0 < apps.limit <= apps.size: | |||
| elif resource == "apps" and 0 < apps.limit <= apps.size: | |||
| raise Forbidden(error_msg) | |||
| elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: | |||
| elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: | |||
| raise Forbidden(error_msg) | |||
| elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: | |||
| elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: | |||
| raise Forbidden(error_msg) | |||
| else: | |||
| return view(*args, **kwargs) | |||
| return view(*args, **kwargs) | |||
| return decorated | |||
| return interceptor | |||
| def cloud_edition_billing_knowledge_limit_check(resource: str, | |||
| api_token_type: str, | |||
| error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): | |||
| def cloud_edition_billing_knowledge_limit_check( | |||
| resource: str, | |||
| api_token_type: str, | |||
| error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", | |||
| ): | |||
| def interceptor(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| api_token = validate_and_get_api_token(api_token_type) | |||
| features = FeatureService.get_features(api_token.tenant_id) | |||
| if features.billing.enabled: | |||
| if resource == 'add_segment': | |||
| if features.billing.subscription.plan == 'sandbox': | |||
| if resource == "add_segment": | |||
| if features.billing.subscription.plan == "sandbox": | |||
| raise Forbidden(error_msg) | |||
| else: | |||
| return view(*args, **kwargs) | |||
| @@ -132,17 +138,20 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, | |||
| return interceptor | |||
| def validate_dataset_token(view=None): | |||
| def decorator(view): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| api_token = validate_and_get_api_token('dataset') | |||
| tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ | |||
| .filter(Tenant.id == api_token.tenant_id) \ | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | |||
| .filter(TenantAccountJoin.role.in_(['owner'])) \ | |||
| .filter(Tenant.status == TenantStatus.NORMAL) \ | |||
| .one_or_none() # TODO: only owner information is required, so only one is returned. | |||
| api_token = validate_and_get_api_token("dataset") | |||
| tenant_account_join = ( | |||
| db.session.query(Tenant, TenantAccountJoin) | |||
| .filter(Tenant.id == api_token.tenant_id) | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) | |||
| .filter(TenantAccountJoin.role.in_(["owner"])) | |||
| .filter(Tenant.status == TenantStatus.NORMAL) | |||
| .one_or_none() | |||
| ) # TODO: only owner information is required, so only one is returned. | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| account = Account.query.filter_by(id=ta.account_id).first() | |||
| @@ -156,6 +165,7 @@ def validate_dataset_token(view=None): | |||
| else: | |||
| raise Unauthorized("Tenant does not exist.") | |||
| return view(api_token.tenant_id, *args, **kwargs) | |||
| return decorated | |||
| if view: | |||
| @@ -170,20 +180,24 @@ def validate_and_get_api_token(scope=None): | |||
| """ | |||
| Validate and get API token. | |||
| """ | |||
| auth_header = request.headers.get('Authorization') | |||
| if auth_header is None or ' ' not in auth_header: | |||
| auth_header = request.headers.get("Authorization") | |||
| if auth_header is None or " " not in auth_header: | |||
| raise Unauthorized("Authorization header must be provided and start with 'Bearer'") | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| if auth_scheme != "bearer": | |||
| raise Unauthorized("Authorization scheme must be 'Bearer'") | |||
| api_token = db.session.query(ApiToken).filter( | |||
| ApiToken.token == auth_token, | |||
| ApiToken.type == scope, | |||
| ).first() | |||
| api_token = ( | |||
| db.session.query(ApiToken) | |||
| .filter( | |||
| ApiToken.token == auth_token, | |||
| ApiToken.type == scope, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not api_token: | |||
| raise Unauthorized("Access token is invalid") | |||
| @@ -199,23 +213,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] | |||
| Create or update session terminal based on user ID. | |||
| """ | |||
| if not user_id: | |||
| user_id = 'DEFAULT-USER' | |||
| user_id = "DEFAULT-USER" | |||
| end_user = db.session.query(EndUser) \ | |||
| end_user = ( | |||
| db.session.query(EndUser) | |||
| .filter( | |||
| EndUser.tenant_id == app_model.tenant_id, | |||
| EndUser.app_id == app_model.id, | |||
| EndUser.session_id == user_id, | |||
| EndUser.type == 'service_api' | |||
| ).first() | |||
| EndUser.tenant_id == app_model.tenant_id, | |||
| EndUser.app_id == app_model.id, | |||
| EndUser.session_id == user_id, | |||
| EndUser.type == "service_api", | |||
| ) | |||
| .first() | |||
| ) | |||
| if end_user is None: | |||
| end_user = EndUser( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| type='service_api', | |||
| is_anonymous=True if user_id == 'DEFAULT-USER' else False, | |||
| session_id=user_id | |||
| type="service_api", | |||
| is_anonymous=True if user_id == "DEFAULT-USER" else False, | |||
| session_id=user_id, | |||
| ) | |||
| db.session.add(end_user) | |||
| db.session.commit() | |||
| @@ -2,7 +2,7 @@ from flask import Blueprint | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint('web', __name__, url_prefix='/api') | |||
| bp = Blueprint("web", __name__, url_prefix="/api") | |||
| api = ExternalApi(bp) | |||
| @@ -10,33 +10,32 @@ from services.app_service import AppService | |||
| class AppParameterApi(WebApiResource): | |||
| """Resource for app variables.""" | |||
| variable_fields = { | |||
| 'key': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'type': fields.String, | |||
| 'default': fields.String, | |||
| 'max_length': fields.Integer, | |||
| 'options': fields.List(fields.String) | |||
| "key": fields.String, | |||
| "name": fields.String, | |||
| "description": fields.String, | |||
| "type": fields.String, | |||
| "default": fields.String, | |||
| "max_length": fields.Integer, | |||
| "options": fields.List(fields.String), | |||
| } | |||
| system_parameters_fields = { | |||
| 'image_file_size_limit': fields.String | |||
| } | |||
| system_parameters_fields = {"image_file_size_limit": fields.String} | |||
| parameters_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'text_to_speech': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'sensitive_word_avoidance': fields.Raw, | |||
| 'file_upload': fields.Raw, | |||
| 'system_parameters': fields.Nested(system_parameters_fields) | |||
| "opening_statement": fields.String, | |||
| "suggested_questions": fields.Raw, | |||
| "suggested_questions_after_answer": fields.Raw, | |||
| "speech_to_text": fields.Raw, | |||
| "text_to_speech": fields.Raw, | |||
| "retriever_resource": fields.Raw, | |||
| "annotation_reply": fields.Raw, | |||
| "more_like_this": fields.Raw, | |||
| "user_input_form": fields.Raw, | |||
| "sensitive_word_avoidance": fields.Raw, | |||
| "file_upload": fields.Raw, | |||
| "system_parameters": fields.Nested(system_parameters_fields), | |||
| } | |||
| @marshal_with(parameters_fields) | |||
| @@ -53,30 +52,35 @@ class AppParameterApi(WebApiResource): | |||
| app_model_config = app_model.app_model_config | |||
| features_dict = app_model_config.to_dict() | |||
| user_input_form = features_dict.get('user_input_form', []) | |||
| user_input_form = features_dict.get("user_input_form", []) | |||
| return { | |||
| 'opening_statement': features_dict.get('opening_statement'), | |||
| 'suggested_questions': features_dict.get('suggested_questions', []), | |||
| 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', | |||
| {"enabled": False}), | |||
| 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), | |||
| 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), | |||
| 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), | |||
| 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), | |||
| 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), | |||
| 'user_input_form': user_input_form, | |||
| 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', | |||
| {"enabled": False, "type": "", "configs": []}), | |||
| 'file_upload': features_dict.get('file_upload', {"image": { | |||
| "enabled": False, | |||
| "number_limits": 3, | |||
| "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"] | |||
| }}), | |||
| 'system_parameters': { | |||
| 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT | |||
| } | |||
| "opening_statement": features_dict.get("opening_statement"), | |||
| "suggested_questions": features_dict.get("suggested_questions", []), | |||
| "suggested_questions_after_answer": features_dict.get( | |||
| "suggested_questions_after_answer", {"enabled": False} | |||
| ), | |||
| "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), | |||
| "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), | |||
| "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), | |||
| "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), | |||
| "more_like_this": features_dict.get("more_like_this", {"enabled": False}), | |||
| "user_input_form": user_input_form, | |||
| "sensitive_word_avoidance": features_dict.get( | |||
| "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} | |||
| ), | |||
| "file_upload": features_dict.get( | |||
| "file_upload", | |||
| { | |||
| "image": { | |||
| "enabled": False, | |||
| "number_limits": 3, | |||
| "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"], | |||
| } | |||
| }, | |||
| ), | |||
| "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, | |||
| } | |||
| @@ -86,5 +90,5 @@ class AppMeta(WebApiResource): | |||
| return AppService().get_app_meta(app_model) | |||
| api.add_resource(AppParameterApi, '/parameters') | |||
| api.add_resource(AppMeta, '/meta') | |||
| api.add_resource(AppParameterApi, "/parameters") | |||
| api.add_resource(AppMeta, "/meta") | |||
| @@ -31,14 +31,10 @@ from services.errors.audio import ( | |||
| class AudioApi(WebApiResource): | |||
| def post(self, app_model: App, end_user): | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| try: | |||
| response = AudioService.transcript_asr( | |||
| app_model=app_model, | |||
| file=file, | |||
| end_user=end_user | |||
| ) | |||
| response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| @@ -70,34 +66,36 @@ class AudioApi(WebApiResource): | |||
| class TextApi(WebApiResource): | |||
| def post(self, app_model: App, end_user): | |||
| from flask_restful import reqparse | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', type=str, required=False, location='json') | |||
| parser.add_argument('voice', type=str, location='json') | |||
| parser.add_argument('text', type=str, location='json') | |||
| parser.add_argument('streaming', type=bool, location='json') | |||
| parser.add_argument("message_id", type=str, required=False, location="json") | |||
| parser.add_argument("voice", type=str, location="json") | |||
| parser.add_argument("text", type=str, location="json") | |||
| parser.add_argument("streaming", type=bool, location="json") | |||
| args = parser.parse_args() | |||
| message_id = args.get('message_id', None) | |||
| text = args.get('text', None) | |||
| if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict): | |||
| text_to_speech = app_model.workflow.features_dict.get('text_to_speech') | |||
| voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| if ( | |||
| app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] | |||
| and app_model.workflow | |||
| and app_model.workflow.features_dict | |||
| ): | |||
| text_to_speech = app_model.workflow.features_dict.get("text_to_speech") | |||
| voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") | |||
| else: | |||
| try: | |||
| voice = args.get('voice') if args.get( | |||
| 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') | |||
| voice = ( | |||
| args.get("voice") | |||
| if args.get("voice") | |||
| else app_model.app_model_config.text_to_speech_dict.get("voice") | |||
| ) | |||
| except Exception: | |||
| voice = None | |||
| response = AudioService.transcript_tts( | |||
| app_model=app_model, | |||
| message_id=message_id, | |||
| end_user=end_user.external_user_id, | |||
| voice=voice, | |||
| text=text | |||
| app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text | |||
| ) | |||
| return response | |||
| @@ -127,5 +125,5 @@ class TextApi(WebApiResource): | |||
| raise InternalServerError() | |||
| api.add_resource(AudioApi, '/audio-to-text') | |||
| api.add_resource(TextApi, '/text-to-audio') | |||
| api.add_resource(AudioApi, "/audio-to-text") | |||
| api.add_resource(TextApi, "/text-to-audio") | |||
| @@ -28,30 +28,25 @@ from services.app_generate_service import AppGenerateService | |||
| # define completion api for user | |||
| class CompletionApi(WebApiResource): | |||
| def post(self, app_model, end_user): | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, location="json", default="") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| args['auto_generate_name'] = False | |||
| streaming = args["response_mode"] == "streaming" | |||
| args["auto_generate_name"] = False | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| streaming=streaming | |||
| app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -79,12 +74,12 @@ class CompletionApi(WebApiResource): | |||
| class CompletionStopApi(WebApiResource): | |||
| def post(self, app_model, end_user, task_id): | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| class ChatApi(WebApiResource): | |||
| @@ -94,25 +89,21 @@ class ChatApi(WebApiResource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('files', type=list, required=False, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("conversation_id", type=uuid_value, location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| args['auto_generate_name'] = False | |||
| streaming = args["response_mode"] == "streaming" | |||
| args["auto_generate_name"] = False | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| streaming=streaming | |||
| app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -146,10 +137,10 @@ class ChatStopApi(WebApiResource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(CompletionApi, '/completion-messages') | |||
| api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop') | |||
| api.add_resource(ChatApi, '/chat-messages') | |||
| api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop') | |||
| api.add_resource(CompletionApi, "/completion-messages") | |||
| api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop") | |||
| api.add_resource(ChatApi, "/chat-messages") | |||
| api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop") | |||
| @@ -15,7 +15,6 @@ from services.web_conversation_service import WebConversationService | |||
| class ConversationListApi(WebApiResource): | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| def get(self, app_model, end_user): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| @@ -23,26 +22,32 @@ class ConversationListApi(WebApiResource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('last_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') | |||
| parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], | |||
| required=False, default='-updated_at', location='args') | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") | |||
| parser.add_argument( | |||
| "sort_by", | |||
| type=str, | |||
| choices=["created_at", "-created_at", "updated_at", "-updated_at"], | |||
| required=False, | |||
| default="-updated_at", | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| pinned = None | |||
| if 'pinned' in args and args['pinned'] is not None: | |||
| pinned = True if args['pinned'] == 'true' else False | |||
| if "pinned" in args and args["pinned"] is not None: | |||
| pinned = True if args["pinned"] == "true" else False | |||
| try: | |||
| return WebConversationService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| last_id=args['last_id'], | |||
| limit=args['limit'], | |||
| last_id=args["last_id"], | |||
| limit=args["limit"], | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| pinned=pinned, | |||
| sort_by=args['sort_by'] | |||
| sort_by=args["sort_by"], | |||
| ) | |||
| except LastConversationNotExistsError: | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| @@ -65,7 +70,6 @@ class ConversationApi(WebApiResource): | |||
| class ConversationRenameApi(WebApiResource): | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, app_model, end_user, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| @@ -75,24 +79,17 @@ class ConversationRenameApi(WebApiResource): | |||
| conversation_id = str(c_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=False, location='json') | |||
| parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') | |||
| parser.add_argument("name", type=str, required=False, location="json") | |||
| parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| return ConversationService.rename( | |||
| app_model, | |||
| conversation_id, | |||
| end_user, | |||
| args['name'], | |||
| args['auto_generate'] | |||
| ) | |||
| return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) | |||
| except ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| class ConversationPinApi(WebApiResource): | |||
| def patch(self, app_model, end_user, c_id): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: | |||
| @@ -120,8 +117,8 @@ class ConversationUnPinApi(WebApiResource): | |||
| return {"result": "success"} | |||
| api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='web_conversation_name') | |||
| api.add_resource(ConversationListApi, '/conversations') | |||
| api.add_resource(ConversationApi, '/conversations/<uuid:c_id>') | |||
| api.add_resource(ConversationPinApi, '/conversations/<uuid:c_id>/pin') | |||
| api.add_resource(ConversationUnPinApi, '/conversations/<uuid:c_id>/unpin') | |||
| api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="web_conversation_name") | |||
| api.add_resource(ConversationListApi, "/conversations") | |||
| api.add_resource(ConversationApi, "/conversations/<uuid:c_id>") | |||
| api.add_resource(ConversationPinApi, "/conversations/<uuid:c_id>/pin") | |||
| api.add_resource(ConversationUnPinApi, "/conversations/<uuid:c_id>/unpin") | |||
| @@ -2,122 +2,126 @@ from libs.exception import BaseHTTPException | |||
| class AppUnavailableError(BaseHTTPException): | |||
| error_code = 'app_unavailable' | |||
| error_code = "app_unavailable" | |||
| description = "App unavailable, please check your app configurations." | |||
| code = 400 | |||
| class NotCompletionAppError(BaseHTTPException): | |||
| error_code = 'not_completion_app' | |||
| error_code = "not_completion_app" | |||
| description = "Please check if your Completion app mode matches the right API route." | |||
| code = 400 | |||
| class NotChatAppError(BaseHTTPException): | |||
| error_code = 'not_chat_app' | |||
| error_code = "not_chat_app" | |||
| description = "Please check if your app mode matches the right API route." | |||
| code = 400 | |||
| class NotWorkflowAppError(BaseHTTPException): | |||
| error_code = 'not_workflow_app' | |||
| error_code = "not_workflow_app" | |||
| description = "Please check if your Workflow app mode matches the right API route." | |||
| code = 400 | |||
| class ConversationCompletedError(BaseHTTPException): | |||
| error_code = 'conversation_completed' | |||
| error_code = "conversation_completed" | |||
| description = "The conversation has ended. Please start a new conversation." | |||
| code = 400 | |||
| class ProviderNotInitializeError(BaseHTTPException): | |||
| error_code = 'provider_not_initialize' | |||
| description = "No valid model provider credentials found. " \ | |||
| "Please go to Settings -> Model Provider to complete your provider credentials." | |||
| error_code = "provider_not_initialize" | |||
| description = ( | |||
| "No valid model provider credentials found. " | |||
| "Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| code = 400 | |||
| class ProviderQuotaExceededError(BaseHTTPException): | |||
| error_code = 'provider_quota_exceeded' | |||
| description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ | |||
| "Please go to Settings -> Model Provider to complete your own provider credentials." | |||
| error_code = "provider_quota_exceeded" | |||
| description = ( | |||
| "Your quota for Dify Hosted OpenAI has been exhausted. " | |||
| "Please go to Settings -> Model Provider to complete your own provider credentials." | |||
| ) | |||
| code = 400 | |||
| class ProviderModelCurrentlyNotSupportError(BaseHTTPException): | |||
| error_code = 'model_currently_not_support' | |||
| error_code = "model_currently_not_support" | |||
| description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." | |||
| code = 400 | |||
| class CompletionRequestError(BaseHTTPException): | |||
| error_code = 'completion_request_error' | |||
| error_code = "completion_request_error" | |||
| description = "Completion request failed." | |||
| code = 400 | |||
| class AppMoreLikeThisDisabledError(BaseHTTPException): | |||
| error_code = 'app_more_like_this_disabled' | |||
| error_code = "app_more_like_this_disabled" | |||
| description = "The 'More like this' feature is disabled. Please refresh your page." | |||
| code = 403 | |||
| class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): | |||
| error_code = 'app_suggested_questions_after_answer_disabled' | |||
| error_code = "app_suggested_questions_after_answer_disabled" | |||
| description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page." | |||
| code = 403 | |||
| class NoAudioUploadedError(BaseHTTPException): | |||
| error_code = 'no_audio_uploaded' | |||
| error_code = "no_audio_uploaded" | |||
| description = "Please upload your audio." | |||
| code = 400 | |||
| class AudioTooLargeError(BaseHTTPException): | |||
| error_code = 'audio_too_large' | |||
| error_code = "audio_too_large" | |||
| description = "Audio size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedAudioTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_audio_type' | |||
| error_code = "unsupported_audio_type" | |||
| description = "Audio type not allowed." | |||
| code = 415 | |||
| class ProviderNotSupportSpeechToTextError(BaseHTTPException): | |||
| error_code = 'provider_not_support_speech_to_text' | |||
| error_code = "provider_not_support_speech_to_text" | |||
| description = "Provider not support speech to text." | |||
| code = 400 | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| error_code = "no_file_uploaded" | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| error_code = "too_many_files" | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| class FileTooLargeError(BaseHTTPException): | |||
| error_code = 'file_too_large' | |||
| error_code = "file_too_large" | |||
| description = "File size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| error_code = "unsupported_file_type" | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| class WebSSOAuthRequiredError(BaseHTTPException): | |||
| error_code = 'web_sso_auth_required' | |||
| error_code = "web_sso_auth_required" | |||
| description = "Web SSO authentication required." | |||
| code = 401 | |||
| @@ -9,4 +9,4 @@ class SystemFeatureApi(Resource): | |||
| return FeatureService.get_system_features().model_dump() | |||
| api.add_resource(SystemFeatureApi, '/system-features') | |||
| api.add_resource(SystemFeatureApi, "/system-features") | |||
| @@ -10,14 +10,13 @@ from services.file_service import FileService | |||
| class FileApi(WebApiResource): | |||
| @marshal_with(file_fields) | |||
| def post(self, app_model, end_user): | |||
| # get file from request | |||
| file = request.files['file'] | |||
| file = request.files["file"] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| @@ -32,4 +31,4 @@ class FileApi(WebApiResource): | |||
| return upload_file, 201 | |||
| api.add_resource(FileApi, '/files/upload') | |||
| api.add_resource(FileApi, "/files/upload") | |||
| @@ -33,48 +33,46 @@ from services.message_service import MessageService | |||
| class MessageListApi(WebApiResource): | |||
| feedback_fields = { | |||
| 'rating': fields.String | |||
| } | |||
| feedback_fields = {"rating": fields.String} | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| "id": fields.String, | |||
| "message_id": fields.String, | |||
| "position": fields.Integer, | |||
| "dataset_id": fields.String, | |||
| "dataset_name": fields.String, | |||
| "document_id": fields.String, | |||
| "document_name": fields.String, | |||
| "data_source_type": fields.String, | |||
| "segment_id": fields.String, | |||
| "score": fields.Float, | |||
| "hit_count": fields.Integer, | |||
| "word_count": fields.Integer, | |||
| "segment_position": fields.Integer, | |||
| "index_node_hash": fields.String, | |||
| "content": fields.String, | |||
| "created_at": TimestampField, | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'answer': fields.String(attribute='re_sign_file_url_answer'), | |||
| 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField, | |||
| 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), | |||
| 'status': fields.String, | |||
| 'error': fields.String, | |||
| "id": fields.String, | |||
| "conversation_id": fields.String, | |||
| "inputs": fields.Raw, | |||
| "query": fields.String, | |||
| "answer": fields.String(attribute="re_sign_file_url_answer"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), | |||
| "created_at": TimestampField, | |||
| "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), | |||
| "status": fields.String, | |||
| "error": fields.String, | |||
| } | |||
| message_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(message_fields)) | |||
| "limit": fields.Integer, | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(message_fields)), | |||
| } | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| @@ -84,14 +82,15 @@ class MessageListApi(WebApiResource): | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') | |||
| parser.add_argument('first_id', type=uuid_value, location='args') | |||
| parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') | |||
| parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") | |||
| parser.add_argument("first_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| try: | |||
| return MessageService.pagination_by_first_id(app_model, end_user, | |||
| args['conversation_id'], args['first_id'], args['limit']) | |||
| return MessageService.pagination_by_first_id( | |||
| app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] | |||
| ) | |||
| except services.errors.conversation.ConversationNotExistsError: | |||
| raise NotFound("Conversation Not Exists.") | |||
| except services.errors.message.FirstMessageNotExistsError: | |||
| @@ -103,29 +102,31 @@ class MessageFeedbackApi(WebApiResource): | |||
| message_id = str(message_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') | |||
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| MessageService.create_feedback(app_model, message_id, end_user, args['rating']) | |||
| MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) | |||
| except services.errors.message.MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| class MessageMoreLikeThisApi(WebApiResource): | |||
| def get(self, app_model, end_user, message_id): | |||
| if app_model.mode != 'completion': | |||
| if app_model.mode != "completion": | |||
| raise NotCompletionAppError() | |||
| message_id = str(message_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') | |||
| parser.add_argument( | |||
| "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" | |||
| ) | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| streaming = args["response_mode"] == "streaming" | |||
| try: | |||
| response = AppGenerateService.generate_more_like_this( | |||
| @@ -133,7 +134,7 @@ class MessageMoreLikeThisApi(WebApiResource): | |||
| user=end_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| streaming=streaming | |||
| streaming=streaming, | |||
| ) | |||
| return helper.compact_generate_response(response) | |||
| @@ -166,10 +167,7 @@ class MessageSuggestedQuestionApi(WebApiResource): | |||
| try: | |||
| questions = MessageService.get_suggested_questions_after_answer( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.WEB_APP | |||
| app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP | |||
| ) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message not found") | |||
| @@ -189,10 +187,10 @@ class MessageSuggestedQuestionApi(WebApiResource): | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {'data': questions} | |||
| return {"data": questions} | |||
| api.add_resource(MessageListApi, '/messages') | |||
| api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks') | |||
| api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this') | |||
| api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions') | |||
| api.add_resource(MessageListApi, "/messages") | |||
| api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks") | |||
| api.add_resource(MessageMoreLikeThisApi, "/messages/<uuid:message_id>/more-like-this") | |||
| api.add_resource(MessageSuggestedQuestionApi, "/messages/<uuid:message_id>/suggested-questions") | |||
| @@ -15,33 +15,31 @@ from services.feature_service import FeatureService | |||
| class PassportResource(Resource): | |||
| """Base resource for passport.""" | |||
| def get(self): | |||
| system_features = FeatureService.get_system_features() | |||
| app_code = request.headers.get('X-App-Code') | |||
| app_code = request.headers.get("X-App-Code") | |||
| if app_code is None: | |||
| raise Unauthorized('X-App-Code header is missing.') | |||
| raise Unauthorized("X-App-Code header is missing.") | |||
| if system_features.sso_enforced_for_web: | |||
| app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) | |||
| app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) | |||
| if app_web_sso_enabled: | |||
| raise WebSSOAuthRequiredError() | |||
| # get site from db and check if it is normal | |||
| site = db.session.query(Site).filter( | |||
| Site.code == app_code, | |||
| Site.status == 'normal' | |||
| ).first() | |||
| site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() | |||
| if not site: | |||
| raise NotFound() | |||
| # get app from db and check if it is normal and enable_site | |||
| app_model = db.session.query(App).filter(App.id == site.app_id).first() | |||
| if not app_model or app_model.status != 'normal' or not app_model.enable_site: | |||
| if not app_model or app_model.status != "normal" or not app_model.enable_site: | |||
| raise NotFound() | |||
| end_user = EndUser( | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| type='browser', | |||
| type="browser", | |||
| is_anonymous=True, | |||
| session_id=generate_session_id(), | |||
| ) | |||
| @@ -51,20 +49,20 @@ class PassportResource(Resource): | |||
| payload = { | |||
| "iss": site.app_id, | |||
| 'sub': 'Web API Passport', | |||
| 'app_id': site.app_id, | |||
| 'app_code': app_code, | |||
| 'end_user_id': end_user.id, | |||
| "sub": "Web API Passport", | |||
| "app_id": site.app_id, | |||
| "app_code": app_code, | |||
| "end_user_id": end_user.id, | |||
| } | |||
| tk = PassportService().issue(payload) | |||
| return { | |||
| 'access_token': tk, | |||
| "access_token": tk, | |||
| } | |||
| api.add_resource(PassportResource, '/passport') | |||
| api.add_resource(PassportResource, "/passport") | |||
| def generate_session_id(): | |||
| @@ -73,7 +71,6 @@ def generate_session_id(): | |||
| """ | |||
| while True: | |||
| session_id = str(uuid.uuid4()) | |||
| existing_count = db.session.query(EndUser) \ | |||
| .filter(EndUser.session_id == session_id).count() | |||
| existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() | |||
| if existing_count == 0: | |||
| return session_id | |||