소스 검색

refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444)

tags/0.7.2
-LAN- 1 년 전
부모
커밋
a10b207de2
No account linked to committer's email address

+ 27
- 36
api/core/app/app_config/easy_ui_based_app/variables/manager.py 파일 보기

import re import re


from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory from core.external_data_tool.factory import ExternalDataToolFactory




:param config: model config args :param config: model config args
""" """
external_data_variables = [] external_data_variables = []
variables = []
variable_entities = []


# old external_data_tools # old external_data_tools
external_data_tools = config.get('external_data_tools', []) external_data_tools = config.get('external_data_tools', [])
) )


# variables and external_data_tools # variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
if 'config' not in val:
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
continue continue


external_data_variables.append( external_data_variables.append(
ExternalDataVariableEntity( ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
variable=variable['variable'],
type=variable['type'],
config=variable['config']
) )
) )
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]: ]:
variables.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
variable = variables[variable_type]
variable_entities.append(
VariableEntity( VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
) )
) )


return variables, external_data_variables
return variable_entities, external_data_variables


@classmethod @classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
config=config config=config
) )


return config, ["external_data_tools"]
return config, ["external_data_tools"]

+ 10
- 24
api/core/app/app_config/entities.py 파일 보기

advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None




class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"


class VariableEntity(BaseModel): class VariableEntity(BaseModel):
""" """
Variable Entity. Variable Entity.
""" """
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'

@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.

:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')


variable: str variable: str
label: str label: str
description: Optional[str] = None description: Optional[str] = None
type: Type
type: VariableEntityType
required: bool = False required: bool = False
max_length: Optional[int] = None max_length: Optional[int] = None
options: Optional[list[str]] = None options: Optional[list[str]] = None
default: Optional[str] = None default: Optional[str] = None
hint: Optional[str] = None hint: Optional[str] = None


@property
def name(self) -> str:
return self.variable



class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
""" """
""" """
Workflow UI Based App Config Entity. Workflow UI Based App Config Entity.
""" """
workflow_id: str
workflow_id: str

+ 14
- 14
api/core/app/apps/base_app_generator.py 파일 보기

from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional from typing import Any, Optional


from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType




class BaseAppGenerator: class BaseAppGenerator:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()} filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs return filtered_inputs


def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value: if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value: if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None? # TODO: should we return None here if the default value is None?
return var.default or '' return var.default or ''
if ( if (
var.type var.type
in ( in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
) )
and user_input_value and user_input_value
and not isinstance(user_input_value, str) and not isinstance(user_input_value, str)
): ):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number # may raise ValueError if user_input_value is not a valid number
try: try:
if '.' in user_input_value: if '.' in user_input_value:
else: else:
return int(user_input_value) return int(user_input_value)
except ValueError: except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or [] options = var.options or []
if user_input_value not in options: if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length: if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')


return user_input_value return user_input_value



+ 19
- 24
api/core/tools/provider/workflow_tool_provider.py 파일 보기

from typing import Optional from typing import Optional


from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.workflow import Workflow from models.workflow import Workflow


VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}



class WorkflowToolProviderController(ToolProviderController): class WorkflowToolProviderController(ToolProviderController):
provider_id: str provider_id: str


if not app: if not app:
raise ValueError('app not found') raise ValueError('app not found')
controller = WorkflowToolProviderController(**{ controller = WorkflowToolProviderController(**{
'identity': { 'identity': {
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '', 'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
'credentials_schema': {}, 'credentials_schema': {},
'provider_id': db_provider.id or '', 'provider_id': db_provider.id or '',
}) })
# init tools # init tools


controller.tools = [controller._get_db_provider_tool(db_provider, app)] controller.tools = [controller._get_db_provider_tool(db_provider, app)]
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:
return ToolProviderType.WORKFLOW return ToolProviderType.WORKFLOW
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool: def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
""" """
get db provider tool get db provider tool
if variable: if variable:
parameter_type = None parameter_type = None
options = None options = None
if variable.type in [
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}') raise ValueError(f'unsupported variable type {variable.type}')
if variable.type == VariableEntity.Type.SELECT and variable.options:
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]

if variable.type == VariableEntityType.SELECT and variable.options:
options = [ options = [
ToolParameterOption( ToolParameterOption(
value=option, value=option,
""" """
if self.tools is not None: if self.tools is not None:
return self.tools return self.tools
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.app_id == self.provider_id,


if not db_providers: if not db_providers:
return [] return []
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)] self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]


return self.tools return self.tools
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]: def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
""" """
get tool by name get tool by name
for tool in self.tools: for tool in self.tools:
if tool.identity.name == tool_name: if tool.identity.name == tool_name:
return tool return tool
return None return None

+ 5
- 1
api/core/workflow/nodes/start/entities.py 파일 보기

from collections.abc import Sequence

from pydantic import Field

from core.app.app_config.entities import VariableEntity from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData


""" """
Start Node Data Start Node Data
""" """
variables: list[VariableEntity] = []
variables: Sequence[VariableEntity] = Field(default_factory=list)

+ 7
- 5
api/tests/unit_tests/services/workflow/test_workflow_converter.py 파일 보기

ModelConfigEntity, ModelConfigEntity,
PromptTemplateEntity, PromptTemplateEntity,
VariableEntity, VariableEntity,
VariableEntityType,
) )
from core.helper import encrypter from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode


@pytest.fixture @pytest.fixture
def default_variables(): def default_variables():
return [
value = [
VariableEntity( VariableEntity(
variable="text_input", variable="text_input",
label="text-input", label="text-input",
type=VariableEntity.Type.TEXT_INPUT
type=VariableEntityType.TEXT_INPUT,
), ),
VariableEntity( VariableEntity(
variable="paragraph", variable="paragraph",
label="paragraph", label="paragraph",
type=VariableEntity.Type.PARAGRAPH
type=VariableEntityType.PARAGRAPH,
), ),
VariableEntity( VariableEntity(
variable="select", variable="select",
label="select", label="select",
type=VariableEntity.Type.SELECT
)
type=VariableEntityType.SELECT,
),
] ]
return value




def test__convert_to_start_node(default_variables): def test__convert_to_start_node(default_variables):

Loading…
취소
저장