|
|
|
@@ -175,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode): |
|
|
|
dataset_retrieval = DatasetRetrieval() |
|
|
|
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: |
|
|
|
# fetch model config |
|
|
|
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore |
|
|
|
if node_data.single_retrieval_config is None: |
|
|
|
raise ValueError("single_retrieval_config is required") |
|
|
|
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model) |
|
|
|
# check model is support tool calling |
|
|
|
model_type_instance = model_config.provider_model_bundle.model_type_instance |
|
|
|
model_type_instance = cast(LargeLanguageModel, model_type_instance) |
|
|
|
@@ -426,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode): |
|
|
|
raise ValueError("metadata_model_config is required") |
|
|
|
# get metadata model instance |
|
|
|
# fetch model config |
|
|
|
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore |
|
|
|
model_instance, model_config = self.get_model_config(metadata_model_config) |
|
|
|
# fetch prompt messages |
|
|
|
prompt_template = self._get_prompt_template( |
|
|
|
node_data=node_data, |
|
|
|
@@ -552,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode): |
|
|
|
variable_mapping[node_id + ".query"] = node_data.query_variable_selector |
|
|
|
return variable_mapping |
|
|
|
|
|
|
|
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore |
|
|
|
""" |
|
|
|
Fetch model config |
|
|
|
:param model: model |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if model is None: |
|
|
|
raise ValueError("model is required") |
|
|
|
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: |
|
|
|
model_name = model.name |
|
|
|
provider_name = model.provider |
|
|
|
|