|
|
|
@@ -56,6 +56,15 @@ from .entities import ( |
|
|
|
LLMNodeData, |
|
|
|
ModelConfig, |
|
|
|
) |
|
|
|
from .exc import ( |
|
|
|
InvalidContextStructureError, |
|
|
|
InvalidVariableTypeError, |
|
|
|
LLMModeRequiredError, |
|
|
|
LLMNodeError, |
|
|
|
ModelNotExistError, |
|
|
|
NoPromptFoundError, |
|
|
|
VariableNotFoundError, |
|
|
|
) |
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from core.file.models import File |
|
|
|
@@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
if self.node_data.memory: |
|
|
|
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) |
|
|
|
if not query: |
|
|
|
raise ValueError("Query not found") |
|
|
|
raise VariableNotFoundError("Query not found") |
|
|
|
query = query.text |
|
|
|
else: |
|
|
|
query = None |
|
|
|
@@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
usage = event.usage |
|
|
|
finish_reason = event.finish_reason |
|
|
|
break |
|
|
|
except Exception as e: |
|
|
|
except LLMNodeError as e: |
|
|
|
yield RunCompletedEvent( |
|
|
|
run_result=NodeRunResult( |
|
|
|
status=WorkflowNodeExecutionStatus.FAILED, |
|
|
|
@@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
variable_name = variable_selector.variable |
|
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) |
|
|
|
if variable is None: |
|
|
|
raise ValueError(f"Variable {variable_selector.variable} not found") |
|
|
|
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") |
|
|
|
|
|
|
|
def parse_dict(input_dict: Mapping[str, Any]) -> str: |
|
|
|
""" |
|
|
|
@@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
for variable_selector in variable_selectors: |
|
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) |
|
|
|
if variable is None: |
|
|
|
raise ValueError(f"Variable {variable_selector.variable} not found") |
|
|
|
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") |
|
|
|
if isinstance(variable, NoneSegment): |
|
|
|
inputs[variable_selector.variable] = "" |
|
|
|
inputs[variable_selector.variable] = variable.to_object() |
|
|
|
@@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
for variable_selector in query_variable_selectors: |
|
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) |
|
|
|
if variable is None: |
|
|
|
raise ValueError(f"Variable {variable_selector.variable} not found") |
|
|
|
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") |
|
|
|
if isinstance(variable, NoneSegment): |
|
|
|
continue |
|
|
|
inputs[variable_selector.variable] = variable.to_object() |
|
|
|
@@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
return variable.value |
|
|
|
elif isinstance(variable, NoneSegment | ArrayAnySegment): |
|
|
|
return [] |
|
|
|
raise ValueError(f"Invalid variable type: {type(variable)}") |
|
|
|
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") |
|
|
|
|
|
|
|
def _fetch_context(self, node_data: LLMNodeData): |
|
|
|
if not node_data.context.enabled: |
|
|
|
@@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
context_str += item + "\n" |
|
|
|
else: |
|
|
|
if "content" not in item: |
|
|
|
raise ValueError(f"Invalid context structure: {item}") |
|
|
|
raise InvalidContextStructureError(f"Invalid context structure: {item}") |
|
|
|
|
|
|
|
context_str += item["content"] + "\n" |
|
|
|
|
|
|
|
@@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
) |
|
|
|
|
|
|
|
if provider_model is None: |
|
|
|
raise ValueError(f"Model {model_name} not exist.") |
|
|
|
raise ModelNotExistError(f"Model {model_name} not exist.") |
|
|
|
|
|
|
|
if provider_model.status == ModelStatus.NO_CONFIGURE: |
|
|
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") |
|
|
|
@@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
# get model mode |
|
|
|
model_mode = node_data_model.mode |
|
|
|
if not model_mode: |
|
|
|
raise ValueError("LLM mode is required.") |
|
|
|
raise LLMModeRequiredError("LLM mode is required.") |
|
|
|
|
|
|
|
model_schema = model_type_instance.get_model_schema(model_name, model_credentials) |
|
|
|
|
|
|
|
if not model_schema: |
|
|
|
raise ValueError(f"Model {model_name} not exist.") |
|
|
|
raise ModelNotExistError(f"Model {model_name} not exist.") |
|
|
|
|
|
|
|
return model_instance, ModelConfigWithCredentialsEntity( |
|
|
|
provider=provider_name, |
|
|
|
@@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
filtered_prompt_messages.append(prompt_message) |
|
|
|
|
|
|
|
if not filtered_prompt_messages: |
|
|
|
raise ValueError( |
|
|
|
raise NoPromptFoundError( |
|
|
|
"No prompt found in the LLM configuration. " |
|
|
|
"Please ensure a prompt is properly configured before proceeding." |
|
|
|
) |
|
|
|
@@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]): |
|
|
|
variable_template_parser = VariableTemplateParser(template=prompt_template.text) |
|
|
|
variable_selectors = variable_template_parser.extract_variable_selectors() |
|
|
|
else: |
|
|
|
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") |
|
|
|
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") |
|
|
|
|
|
|
|
variable_mapping = {} |
|
|
|
for variable_selector in variable_selectors: |