| @@ -1,6 +1,6 @@ | |||
| import re | |||
| from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity | |||
| from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType | |||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||
| @@ -13,7 +13,7 @@ class BasicVariablesConfigManager: | |||
| :param config: model config args | |||
| """ | |||
| external_data_variables = [] | |||
| variables = [] | |||
| variable_entities = [] | |||
| # old external_data_tools | |||
| external_data_tools = config.get('external_data_tools', []) | |||
| @@ -30,50 +30,41 @@ class BasicVariablesConfigManager: | |||
| ) | |||
| # variables and external_data_tools | |||
| for variable in config.get('user_input_form', []): | |||
| typ = list(variable.keys())[0] | |||
| if typ == 'external_data_tool': | |||
| val = variable[typ] | |||
| if 'config' not in val: | |||
| for variables in config.get('user_input_form', []): | |||
| variable_type = list(variables.keys())[0] | |||
| if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: | |||
| variable = variables[variable_type] | |||
| if 'config' not in variable: | |||
| continue | |||
| external_data_variables.append( | |||
| ExternalDataVariableEntity( | |||
| variable=val['variable'], | |||
| type=val['type'], | |||
| config=val['config'] | |||
| variable=variable['variable'], | |||
| type=variable['type'], | |||
| config=variable['config'] | |||
| ) | |||
| ) | |||
| elif typ in [ | |||
| VariableEntity.Type.TEXT_INPUT.value, | |||
| VariableEntity.Type.PARAGRAPH.value, | |||
| VariableEntity.Type.NUMBER.value, | |||
| elif variable_type in [ | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.PARAGRAPH, | |||
| VariableEntityType.NUMBER, | |||
| VariableEntityType.SELECT, | |||
| ]: | |||
| variables.append( | |||
| VariableEntity( | |||
| type=VariableEntity.Type.value_of(typ), | |||
| variable=variable[typ].get('variable'), | |||
| description=variable[typ].get('description'), | |||
| label=variable[typ].get('label'), | |||
| required=variable[typ].get('required', False), | |||
| max_length=variable[typ].get('max_length'), | |||
| default=variable[typ].get('default'), | |||
| ) | |||
| ) | |||
| elif typ == VariableEntity.Type.SELECT.value: | |||
| variables.append( | |||
| variable = variables[variable_type] | |||
| variable_entities.append( | |||
| VariableEntity( | |||
| type=VariableEntity.Type.SELECT, | |||
| variable=variable[typ].get('variable'), | |||
| description=variable[typ].get('description'), | |||
| label=variable[typ].get('label'), | |||
| required=variable[typ].get('required', False), | |||
| options=variable[typ].get('options'), | |||
| default=variable[typ].get('default'), | |||
| type=variable_type, | |||
| variable=variable.get('variable'), | |||
| description=variable.get('description'), | |||
| label=variable.get('label'), | |||
| required=variable.get('required', False), | |||
| max_length=variable.get('max_length'), | |||
| options=variable.get('options'), | |||
| default=variable.get('default'), | |||
| ) | |||
| ) | |||
| return variables, external_data_variables | |||
| return variable_entities, external_data_variables | |||
| @classmethod | |||
| def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: | |||
| @@ -183,4 +174,4 @@ class BasicVariablesConfigManager: | |||
| config=config | |||
| ) | |||
| return config, ["external_data_tools"] | |||
| return config, ["external_data_tools"] | |||
| @@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel): | |||
| advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None | |||
| class VariableEntityType(str, Enum): | |||
| TEXT_INPUT = "text-input" | |||
| SELECT = "select" | |||
| PARAGRAPH = "paragraph" | |||
| NUMBER = "number" | |||
| EXTERNAL_DATA_TOOL = "external-data-tool" | |||
| class VariableEntity(BaseModel): | |||
| """ | |||
| Variable Entity. | |||
| """ | |||
| class Type(Enum): | |||
| TEXT_INPUT = 'text-input' | |||
| SELECT = 'select' | |||
| PARAGRAPH = 'paragraph' | |||
| NUMBER = 'number' | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'VariableEntity.Type': | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid variable type value {value}') | |||
| variable: str | |||
| label: str | |||
| description: Optional[str] = None | |||
| type: Type | |||
| type: VariableEntityType | |||
| required: bool = False | |||
| max_length: Optional[int] = None | |||
| options: Optional[list[str]] = None | |||
| default: Optional[str] = None | |||
| hint: Optional[str] = None | |||
| @property | |||
| def name(self) -> str: | |||
| return self.variable | |||
| class ExternalDataVariableEntity(BaseModel): | |||
| """ | |||
| @@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig): | |||
| """ | |||
| Workflow UI Based App Config Entity. | |||
| """ | |||
| workflow_id: str | |||
| workflow_id: str | |||
| @@ -1,7 +1,7 @@ | |||
| from collections.abc import Mapping | |||
| from typing import Any, Optional | |||
| from core.app.app_config.entities import AppConfig, VariableEntity | |||
| from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType | |||
| class BaseAppGenerator: | |||
| @@ -9,29 +9,29 @@ class BaseAppGenerator: | |||
| user_inputs = user_inputs or {} | |||
| # Filter input variables from form configuration, handle required fields, default values, and option values | |||
| variables = app_config.variables | |||
| filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} | |||
| filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} | |||
| filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} | |||
| return filtered_inputs | |||
| def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): | |||
| user_input_value = inputs.get(var.name) | |||
| user_input_value = inputs.get(var.variable) | |||
| if var.required and not user_input_value: | |||
| raise ValueError(f'{var.name} is required in input form') | |||
| raise ValueError(f'{var.variable} is required in input form') | |||
| if not var.required and not user_input_value: | |||
| # TODO: should we return None here if the default value is None? | |||
| return var.default or '' | |||
| if ( | |||
| var.type | |||
| in ( | |||
| VariableEntity.Type.TEXT_INPUT, | |||
| VariableEntity.Type.SELECT, | |||
| VariableEntity.Type.PARAGRAPH, | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.PARAGRAPH, | |||
| ) | |||
| and user_input_value | |||
| and not isinstance(user_input_value, str) | |||
| ): | |||
| raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") | |||
| if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): | |||
| raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") | |||
| if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): | |||
| # may raise ValueError if user_input_value is not a valid number | |||
| try: | |||
| if '.' in user_input_value: | |||
| @@ -39,14 +39,14 @@ class BaseAppGenerator: | |||
| else: | |||
| return int(user_input_value) | |||
| except ValueError: | |||
| raise ValueError(f"{var.name} in input form must be a valid number") | |||
| if var.type == VariableEntity.Type.SELECT: | |||
| raise ValueError(f"{var.variable} in input form must be a valid number") | |||
| if var.type == VariableEntityType.SELECT: | |||
| options = var.options or [] | |||
| if user_input_value not in options: | |||
| raise ValueError(f'{var.name} in input form must be one of the following: {options}') | |||
| elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): | |||
| raise ValueError(f'{var.variable} in input form must be one of the following: {options}') | |||
| elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): | |||
| if var.max_length and user_input_value and len(user_input_value) > var.max_length: | |||
| raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') | |||
| raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') | |||
| return user_input_value | |||
| @@ -1,6 +1,6 @@ | |||
| from typing import Optional | |||
| from core.app.app_config.entities import VariableEntity | |||
| from core.app.app_config.entities import VariableEntity, VariableEntityType | |||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | |||
| from core.tools.entities.common_entities import I18nObject | |||
| from core.tools.entities.tool_entities import ( | |||
| @@ -18,6 +18,13 @@ from models.model import App, AppMode | |||
| from models.tools import WorkflowToolProvider | |||
| from models.workflow import Workflow | |||
| VARIABLE_TO_PARAMETER_TYPE_MAPPING = { | |||
| VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, | |||
| VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, | |||
| VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, | |||
| VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, | |||
| } | |||
| class WorkflowToolProviderController(ToolProviderController): | |||
| provider_id: str | |||
| @@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| if not app: | |||
| raise ValueError('app not found') | |||
| controller = WorkflowToolProviderController(**{ | |||
| 'identity': { | |||
| 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', | |||
| @@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| 'credentials_schema': {}, | |||
| 'provider_id': db_provider.id or '', | |||
| }) | |||
| # init tools | |||
| controller.tools = [controller._get_db_provider_tool(db_provider, app)] | |||
| @@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| @property | |||
| def provider_type(self) -> ToolProviderType: | |||
| return ToolProviderType.WORKFLOW | |||
| def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: | |||
| """ | |||
| get db provider tool | |||
| @@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| if variable: | |||
| parameter_type = None | |||
| options = None | |||
| if variable.type in [ | |||
| VariableEntity.Type.TEXT_INPUT, | |||
| VariableEntity.Type.PARAGRAPH, | |||
| ]: | |||
| parameter_type = ToolParameter.ToolParameterType.STRING | |||
| elif variable.type in [ | |||
| VariableEntity.Type.SELECT | |||
| ]: | |||
| parameter_type = ToolParameter.ToolParameterType.SELECT | |||
| elif variable.type in [ | |||
| VariableEntity.Type.NUMBER | |||
| ]: | |||
| parameter_type = ToolParameter.ToolParameterType.NUMBER | |||
| else: | |||
| if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: | |||
| raise ValueError(f'unsupported variable type {variable.type}') | |||
| if variable.type == VariableEntity.Type.SELECT and variable.options: | |||
| parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] | |||
| if variable.type == VariableEntityType.SELECT and variable.options: | |||
| options = [ | |||
| ToolParameterOption( | |||
| value=option, | |||
| @@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| """ | |||
| if self.tools is not None: | |||
| return self.tools | |||
| db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.app_id == self.provider_id, | |||
| @@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| if not db_providers: | |||
| return [] | |||
| self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] | |||
| return self.tools | |||
| def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: | |||
| """ | |||
| get tool by name | |||
| @@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController): | |||
| for tool in self.tools: | |||
| if tool.identity.name == tool_name: | |||
| return tool | |||
| return None | |||
| @@ -1,3 +1,7 @@ | |||
| from collections.abc import Sequence | |||
| from pydantic import Field | |||
| from core.app.app_config.entities import VariableEntity | |||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | |||
| @@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData): | |||
| """ | |||
| Start Node Data | |||
| """ | |||
| variables: list[VariableEntity] = [] | |||
| variables: Sequence[VariableEntity] = Field(default_factory=list) | |||
| @@ -14,6 +14,7 @@ from core.app.app_config.entities import ( | |||
| ModelConfigEntity, | |||
| PromptTemplateEntity, | |||
| VariableEntity, | |||
| VariableEntityType, | |||
| ) | |||
| from core.helper import encrypter | |||
| from core.model_runtime.entities.llm_entities import LLMMode | |||
| @@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter | |||
| @pytest.fixture | |||
| def default_variables(): | |||
| return [ | |||
| value = [ | |||
| VariableEntity( | |||
| variable="text_input", | |||
| label="text-input", | |||
| type=VariableEntity.Type.TEXT_INPUT | |||
| type=VariableEntityType.TEXT_INPUT, | |||
| ), | |||
| VariableEntity( | |||
| variable="paragraph", | |||
| label="paragraph", | |||
| type=VariableEntity.Type.PARAGRAPH | |||
| type=VariableEntityType.PARAGRAPH, | |||
| ), | |||
| VariableEntity( | |||
| variable="select", | |||
| label="select", | |||
| type=VariableEntity.Type.SELECT | |||
| ) | |||
| type=VariableEntityType.SELECT, | |||
| ), | |||
| ] | |||
| return value | |||
| def test__convert_to_start_node(default_variables): | |||