| import re | import re | ||||
| from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity | |||||
| from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType | |||||
| from core.external_data_tool.factory import ExternalDataToolFactory | from core.external_data_tool.factory import ExternalDataToolFactory | ||||
| :param config: model config args | :param config: model config args | ||||
| """ | """ | ||||
| external_data_variables = [] | external_data_variables = [] | ||||
| variables = [] | |||||
| variable_entities = [] | |||||
| # old external_data_tools | # old external_data_tools | ||||
| external_data_tools = config.get('external_data_tools', []) | external_data_tools = config.get('external_data_tools', []) | ||||
| ) | ) | ||||
| # variables and external_data_tools | # variables and external_data_tools | ||||
| for variable in config.get('user_input_form', []): | |||||
| typ = list(variable.keys())[0] | |||||
| if typ == 'external_data_tool': | |||||
| val = variable[typ] | |||||
| if 'config' not in val: | |||||
| for variables in config.get('user_input_form', []): | |||||
| variable_type = list(variables.keys())[0] | |||||
| if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL: | |||||
| variable = variables[variable_type] | |||||
| if 'config' not in variable: | |||||
| continue | continue | ||||
| external_data_variables.append( | external_data_variables.append( | ||||
| ExternalDataVariableEntity( | ExternalDataVariableEntity( | ||||
| variable=val['variable'], | |||||
| type=val['type'], | |||||
| config=val['config'] | |||||
| variable=variable['variable'], | |||||
| type=variable['type'], | |||||
| config=variable['config'] | |||||
| ) | ) | ||||
| ) | ) | ||||
| elif typ in [ | |||||
| VariableEntity.Type.TEXT_INPUT.value, | |||||
| VariableEntity.Type.PARAGRAPH.value, | |||||
| VariableEntity.Type.NUMBER.value, | |||||
| elif variable_type in [ | |||||
| VariableEntityType.TEXT_INPUT, | |||||
| VariableEntityType.PARAGRAPH, | |||||
| VariableEntityType.NUMBER, | |||||
| VariableEntityType.SELECT, | |||||
| ]: | ]: | ||||
| variables.append( | |||||
| VariableEntity( | |||||
| type=VariableEntity.Type.value_of(typ), | |||||
| variable=variable[typ].get('variable'), | |||||
| description=variable[typ].get('description'), | |||||
| label=variable[typ].get('label'), | |||||
| required=variable[typ].get('required', False), | |||||
| max_length=variable[typ].get('max_length'), | |||||
| default=variable[typ].get('default'), | |||||
| ) | |||||
| ) | |||||
| elif typ == VariableEntity.Type.SELECT.value: | |||||
| variables.append( | |||||
| variable = variables[variable_type] | |||||
| variable_entities.append( | |||||
| VariableEntity( | VariableEntity( | ||||
| type=VariableEntity.Type.SELECT, | |||||
| variable=variable[typ].get('variable'), | |||||
| description=variable[typ].get('description'), | |||||
| label=variable[typ].get('label'), | |||||
| required=variable[typ].get('required', False), | |||||
| options=variable[typ].get('options'), | |||||
| default=variable[typ].get('default'), | |||||
| type=variable_type, | |||||
| variable=variable.get('variable'), | |||||
| description=variable.get('description'), | |||||
| label=variable.get('label'), | |||||
| required=variable.get('required', False), | |||||
| max_length=variable.get('max_length'), | |||||
| options=variable.get('options'), | |||||
| default=variable.get('default'), | |||||
| ) | ) | ||||
| ) | ) | ||||
| return variables, external_data_variables | |||||
| return variable_entities, external_data_variables | |||||
| @classmethod | @classmethod | ||||
| def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: | def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: | ||||
| config=config | config=config | ||||
| ) | ) | ||||
| return config, ["external_data_tools"] | |||||
| return config, ["external_data_tools"] |
| advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None | advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None | ||||
| class VariableEntityType(str, Enum): | |||||
| TEXT_INPUT = "text-input" | |||||
| SELECT = "select" | |||||
| PARAGRAPH = "paragraph" | |||||
| NUMBER = "number" | |||||
| EXTERNAL_DATA_TOOL = "external-data-tool" | |||||
| class VariableEntity(BaseModel): | class VariableEntity(BaseModel): | ||||
| """ | """ | ||||
| Variable Entity. | Variable Entity. | ||||
| """ | """ | ||||
| class Type(Enum): | |||||
| TEXT_INPUT = 'text-input' | |||||
| SELECT = 'select' | |||||
| PARAGRAPH = 'paragraph' | |||||
| NUMBER = 'number' | |||||
| @classmethod | |||||
| def value_of(cls, value: str) -> 'VariableEntity.Type': | |||||
| """ | |||||
| Get value of given mode. | |||||
| :param value: mode value | |||||
| :return: mode | |||||
| """ | |||||
| for mode in cls: | |||||
| if mode.value == value: | |||||
| return mode | |||||
| raise ValueError(f'invalid variable type value {value}') | |||||
| variable: str | variable: str | ||||
| label: str | label: str | ||||
| description: Optional[str] = None | description: Optional[str] = None | ||||
| type: Type | |||||
| type: VariableEntityType | |||||
| required: bool = False | required: bool = False | ||||
| max_length: Optional[int] = None | max_length: Optional[int] = None | ||||
| options: Optional[list[str]] = None | options: Optional[list[str]] = None | ||||
| default: Optional[str] = None | default: Optional[str] = None | ||||
| hint: Optional[str] = None | hint: Optional[str] = None | ||||
| @property | |||||
| def name(self) -> str: | |||||
| return self.variable | |||||
| class ExternalDataVariableEntity(BaseModel): | class ExternalDataVariableEntity(BaseModel): | ||||
| """ | """ | ||||
| """ | """ | ||||
| Workflow UI Based App Config Entity. | Workflow UI Based App Config Entity. | ||||
| """ | """ | ||||
| workflow_id: str | |||||
| workflow_id: str |
| from collections.abc import Mapping | from collections.abc import Mapping | ||||
| from typing import Any, Optional | from typing import Any, Optional | ||||
| from core.app.app_config.entities import AppConfig, VariableEntity | |||||
| from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType | |||||
| class BaseAppGenerator: | class BaseAppGenerator: | ||||
| user_inputs = user_inputs or {} | user_inputs = user_inputs or {} | ||||
| # Filter input variables from form configuration, handle required fields, default values, and option values | # Filter input variables from form configuration, handle required fields, default values, and option values | ||||
| variables = app_config.variables | variables = app_config.variables | ||||
| filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables} | |||||
| filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} | |||||
| filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} | filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} | ||||
| return filtered_inputs | return filtered_inputs | ||||
| def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): | def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): | ||||
| user_input_value = inputs.get(var.name) | |||||
| user_input_value = inputs.get(var.variable) | |||||
| if var.required and not user_input_value: | if var.required and not user_input_value: | ||||
| raise ValueError(f'{var.name} is required in input form') | |||||
| raise ValueError(f'{var.variable} is required in input form') | |||||
| if not var.required and not user_input_value: | if not var.required and not user_input_value: | ||||
| # TODO: should we return None here if the default value is None? | # TODO: should we return None here if the default value is None? | ||||
| return var.default or '' | return var.default or '' | ||||
| if ( | if ( | ||||
| var.type | var.type | ||||
| in ( | in ( | ||||
| VariableEntity.Type.TEXT_INPUT, | |||||
| VariableEntity.Type.SELECT, | |||||
| VariableEntity.Type.PARAGRAPH, | |||||
| VariableEntityType.TEXT_INPUT, | |||||
| VariableEntityType.SELECT, | |||||
| VariableEntityType.PARAGRAPH, | |||||
| ) | ) | ||||
| and user_input_value | and user_input_value | ||||
| and not isinstance(user_input_value, str) | and not isinstance(user_input_value, str) | ||||
| ): | ): | ||||
| raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string") | |||||
| if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str): | |||||
| raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") | |||||
| if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): | |||||
| # may raise ValueError if user_input_value is not a valid number | # may raise ValueError if user_input_value is not a valid number | ||||
| try: | try: | ||||
| if '.' in user_input_value: | if '.' in user_input_value: | ||||
| else: | else: | ||||
| return int(user_input_value) | return int(user_input_value) | ||||
| except ValueError: | except ValueError: | ||||
| raise ValueError(f"{var.name} in input form must be a valid number") | |||||
| if var.type == VariableEntity.Type.SELECT: | |||||
| raise ValueError(f"{var.variable} in input form must be a valid number") | |||||
| if var.type == VariableEntityType.SELECT: | |||||
| options = var.options or [] | options = var.options or [] | ||||
| if user_input_value not in options: | if user_input_value not in options: | ||||
| raise ValueError(f'{var.name} in input form must be one of the following: {options}') | |||||
| elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH): | |||||
| raise ValueError(f'{var.variable} in input form must be one of the following: {options}') | |||||
| elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): | |||||
| if var.max_length and user_input_value and len(user_input_value) > var.max_length: | if var.max_length and user_input_value and len(user_input_value) > var.max_length: | ||||
| raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters') | |||||
| raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters') | |||||
| return user_input_value | return user_input_value | ||||
| from typing import Optional | from typing import Optional | ||||
| from core.app.app_config.entities import VariableEntity | |||||
| from core.app.app_config.entities import VariableEntity, VariableEntityType | |||||
| from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager | ||||
| from core.tools.entities.common_entities import I18nObject | from core.tools.entities.common_entities import I18nObject | ||||
| from core.tools.entities.tool_entities import ( | from core.tools.entities.tool_entities import ( | ||||
| from models.tools import WorkflowToolProvider | from models.tools import WorkflowToolProvider | ||||
| from models.workflow import Workflow | from models.workflow import Workflow | ||||
| VARIABLE_TO_PARAMETER_TYPE_MAPPING = { | |||||
| VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING, | |||||
| VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING, | |||||
| VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT, | |||||
| VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER, | |||||
| } | |||||
| class WorkflowToolProviderController(ToolProviderController): | class WorkflowToolProviderController(ToolProviderController): | ||||
| provider_id: str | provider_id: str | ||||
| if not app: | if not app: | ||||
| raise ValueError('app not found') | raise ValueError('app not found') | ||||
| controller = WorkflowToolProviderController(**{ | controller = WorkflowToolProviderController(**{ | ||||
| 'identity': { | 'identity': { | ||||
| 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', | 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', | ||||
| 'credentials_schema': {}, | 'credentials_schema': {}, | ||||
| 'provider_id': db_provider.id or '', | 'provider_id': db_provider.id or '', | ||||
| }) | }) | ||||
| # init tools | # init tools | ||||
| controller.tools = [controller._get_db_provider_tool(db_provider, app)] | controller.tools = [controller._get_db_provider_tool(db_provider, app)] | ||||
| @property | @property | ||||
| def provider_type(self) -> ToolProviderType: | def provider_type(self) -> ToolProviderType: | ||||
| return ToolProviderType.WORKFLOW | return ToolProviderType.WORKFLOW | ||||
| def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: | def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: | ||||
| """ | """ | ||||
| get db provider tool | get db provider tool | ||||
| if variable: | if variable: | ||||
| parameter_type = None | parameter_type = None | ||||
| options = None | options = None | ||||
| if variable.type in [ | |||||
| VariableEntity.Type.TEXT_INPUT, | |||||
| VariableEntity.Type.PARAGRAPH, | |||||
| ]: | |||||
| parameter_type = ToolParameter.ToolParameterType.STRING | |||||
| elif variable.type in [ | |||||
| VariableEntity.Type.SELECT | |||||
| ]: | |||||
| parameter_type = ToolParameter.ToolParameterType.SELECT | |||||
| elif variable.type in [ | |||||
| VariableEntity.Type.NUMBER | |||||
| ]: | |||||
| parameter_type = ToolParameter.ToolParameterType.NUMBER | |||||
| else: | |||||
| if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING: | |||||
| raise ValueError(f'unsupported variable type {variable.type}') | raise ValueError(f'unsupported variable type {variable.type}') | ||||
| if variable.type == VariableEntity.Type.SELECT and variable.options: | |||||
| parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type] | |||||
| if variable.type == VariableEntityType.SELECT and variable.options: | |||||
| options = [ | options = [ | ||||
| ToolParameterOption( | ToolParameterOption( | ||||
| value=option, | value=option, | ||||
| """ | """ | ||||
| if self.tools is not None: | if self.tools is not None: | ||||
| return self.tools | return self.tools | ||||
| db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | ||||
| WorkflowToolProvider.tenant_id == tenant_id, | WorkflowToolProvider.tenant_id == tenant_id, | ||||
| WorkflowToolProvider.app_id == self.provider_id, | WorkflowToolProvider.app_id == self.provider_id, | ||||
| if not db_providers: | if not db_providers: | ||||
| return [] | return [] | ||||
| self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] | self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] | ||||
| return self.tools | return self.tools | ||||
| def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: | def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: | ||||
| """ | """ | ||||
| get tool by name | get tool by name | ||||
| for tool in self.tools: | for tool in self.tools: | ||||
| if tool.identity.name == tool_name: | if tool.identity.name == tool_name: | ||||
| return tool | return tool | ||||
| return None | return None |
| from collections.abc import Sequence | |||||
| from pydantic import Field | |||||
| from core.app.app_config.entities import VariableEntity | from core.app.app_config.entities import VariableEntity | ||||
| from core.workflow.entities.base_node_data_entities import BaseNodeData | from core.workflow.entities.base_node_data_entities import BaseNodeData | ||||
| """ | """ | ||||
| Start Node Data | Start Node Data | ||||
| """ | """ | ||||
| variables: list[VariableEntity] = [] | |||||
| variables: Sequence[VariableEntity] = Field(default_factory=list) |
| ModelConfigEntity, | ModelConfigEntity, | ||||
| PromptTemplateEntity, | PromptTemplateEntity, | ||||
| VariableEntity, | VariableEntity, | ||||
| VariableEntityType, | |||||
| ) | ) | ||||
| from core.helper import encrypter | from core.helper import encrypter | ||||
| from core.model_runtime.entities.llm_entities import LLMMode | from core.model_runtime.entities.llm_entities import LLMMode | ||||
| @pytest.fixture | @pytest.fixture | ||||
| def default_variables(): | def default_variables(): | ||||
| return [ | |||||
| value = [ | |||||
| VariableEntity( | VariableEntity( | ||||
| variable="text_input", | variable="text_input", | ||||
| label="text-input", | label="text-input", | ||||
| type=VariableEntity.Type.TEXT_INPUT | |||||
| type=VariableEntityType.TEXT_INPUT, | |||||
| ), | ), | ||||
| VariableEntity( | VariableEntity( | ||||
| variable="paragraph", | variable="paragraph", | ||||
| label="paragraph", | label="paragraph", | ||||
| type=VariableEntity.Type.PARAGRAPH | |||||
| type=VariableEntityType.PARAGRAPH, | |||||
| ), | ), | ||||
| VariableEntity( | VariableEntity( | ||||
| variable="select", | variable="select", | ||||
| label="select", | label="select", | ||||
| type=VariableEntity.Type.SELECT | |||||
| ) | |||||
| type=VariableEntityType.SELECT, | |||||
| ), | |||||
| ] | ] | ||||
| return value | |||||
| def test__convert_to_start_node(default_variables): | def test__convert_to_start_node(default_variables): |