| @@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException): | |||
| error_code = 'draft_workflow_not_exist' | |||
| description = "Draft workflow need to be initialized." | |||
| code = 400 | |||
| class DraftWorkflowNotSync(BaseHTTPException): | |||
| error_code = 'draft_workflow_not_sync' | |||
| description = "Workflow graph might have been modified, please refresh and resubmit." | |||
| code = 400 | |||
| @@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist | |||
| from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync | |||
| from controllers.console.app.wraps import get_app_model | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| @@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value | |||
| from libs.login import current_user, login_required | |||
| from models.model import App, AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| from services.workflow_service import WorkflowService | |||
| logger = logging.getLogger(__name__) | |||
| @@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('features', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('hash', type=str, required=False, location='json') | |||
| args = parser.parse_args() | |||
| elif 'text/plain' in content_type: | |||
| try: | |||
| @@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource): | |||
| args = { | |||
| 'graph': data.get('graph'), | |||
| 'features': data.get('features') | |||
| 'features': data.get('features'), | |||
| 'hash': data.get('hash') | |||
| } | |||
| except json.JSONDecodeError: | |||
| return {'message': 'Invalid JSON data'}, 400 | |||
| @@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource): | |||
| abort(415) | |||
| workflow_service = WorkflowService() | |||
| workflow = workflow_service.sync_draft_workflow( | |||
| app_model=app_model, | |||
| graph=args.get('graph'), | |||
| features=args.get('features'), | |||
| account=current_user | |||
| ) | |||
| try: | |||
| workflow = workflow_service.sync_draft_workflow( | |||
| app_model=app_model, | |||
| graph=args.get('graph'), | |||
| features=args.get('features'), | |||
| unique_hash=args.get('hash'), | |||
| account=current_user | |||
| ) | |||
| except WorkflowHashNotEqualError: | |||
| raise DraftWorkflowNotSync() | |||
| return { | |||
| "result": "success", | |||
| "hash": workflow.unique_hash, | |||
| "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) | |||
| } | |||
| @@ -7,6 +7,7 @@ workflow_fields = { | |||
| 'id': fields.String, | |||
| 'graph': fields.Raw(attribute='graph_dict'), | |||
| 'features': fields.Raw(attribute='features_dict'), | |||
| 'hash': fields.String(attribute='unique_hash'), | |||
| 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), | |||
| 'created_at': TimestampField, | |||
| 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), | |||
| @@ -4,6 +4,7 @@ from typing import Optional, Union | |||
| from core.tools.tool_manager import ToolManager | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models import StringUUID | |||
| from models.account import Account | |||
| @@ -156,6 +157,21 @@ class Workflow(db.Model): | |||
| return variables | |||
| @property | |||
| def unique_hash(self) -> str: | |||
| """ | |||
| Get hash of workflow. | |||
| :return: hash | |||
| """ | |||
| entity = { | |||
| 'graph': self.graph_dict, | |||
| 'features': self.features_dict | |||
| } | |||
| return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) | |||
| class WorkflowRunTriggeredFrom(Enum): | |||
| """ | |||
| Workflow Run Triggered From Enum | |||
| @@ -196,6 +196,7 @@ class AppService: | |||
| app_model=app, | |||
| graph=workflow.get('graph'), | |||
| features=workflow.get('features'), | |||
| unique_hash=None, | |||
| account=account | |||
| ) | |||
| workflow_service.publish_workflow( | |||
| @@ -1,2 +1,6 @@ | |||
| class MoreLikeThisDisabledError(Exception): | |||
| pass | |||
| class WorkflowHashNotEqualError(Exception): | |||
| pass | |||
| @@ -21,6 +21,7 @@ from models.workflow import ( | |||
| WorkflowNodeExecutionTriggeredFrom, | |||
| WorkflowType, | |||
| ) | |||
| from services.errors.app import WorkflowHashNotEqualError | |||
| from services.workflow.workflow_converter import WorkflowConverter | |||
| @@ -63,13 +64,20 @@ class WorkflowService: | |||
| def sync_draft_workflow(self, app_model: App, | |||
| graph: dict, | |||
| features: dict, | |||
| unique_hash: Optional[str], | |||
| account: Account) -> Workflow: | |||
| """ | |||
| Sync draft workflow | |||
| @throws WorkflowHashNotEqualError | |||
| """ | |||
| # fetch draft workflow by app_model | |||
| workflow = self.get_draft_workflow(app_model=app_model) | |||
| if workflow: | |||
| # validate unique hash | |||
| if workflow.unique_hash != unique_hash: | |||
| raise WorkflowHashNotEqualError() | |||
| # validate features structure | |||
| self.validate_features_structure( | |||
| app_model=app_model, | |||