| @@ -130,15 +130,14 @@ class GraphEngine: | |||
| yield GraphRunStartedEvent() | |||
| try: | |||
| stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor] | |||
| if self.init_params.workflow_type == WorkflowType.CHAT: | |||
| stream_processor_cls = AnswerStreamProcessor | |||
| stream_processor = AnswerStreamProcessor( | |||
| graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool | |||
| ) | |||
| else: | |||
| stream_processor_cls = EndStreamProcessor | |||
| stream_processor = stream_processor_cls( | |||
| graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool | |||
| ) | |||
| stream_processor = EndStreamProcessor( | |||
| graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool | |||
| ) | |||
| # run graph | |||
| generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) | |||
| @@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter: | |||
| source_node_id = edge.source_node_id | |||
| source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") | |||
| if source_node_type in { | |||
| NodeType.ANSWER.value, | |||
| NodeType.IF_ELSE.value, | |||
| NodeType.QUESTION_CLASSIFIER.value, | |||
| NodeType.ITERATION.value, | |||
| NodeType.ANSWER, | |||
| NodeType.IF_ELSE, | |||
| NodeType.QUESTION_CLASSIFIER, | |||
| NodeType.ITERATION, | |||
| }: | |||
| answer_dependencies[answer_node_id].append(source_node_id) | |||
| else: | |||
| @@ -22,7 +22,7 @@ class AnswerStreamProcessor(StreamProcessor): | |||
| super().__init__(graph, variable_pool) | |||
| self.generate_routes = graph.answer_stream_generate_routes | |||
| self.route_position = {} | |||
| for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items(): | |||
| for answer_node_id in self.generate_routes.answer_generate_route: | |||
| self.route_position[answer_node_id] = 0 | |||
| self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {} | |||
| @@ -41,7 +41,6 @@ class StreamProcessor(ABC): | |||
| continue | |||
| else: | |||
| unreachable_first_node_ids.append(edge.target_node_id) | |||
| unreachable_first_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) | |||
| for node_id in unreachable_first_node_ids: | |||
| self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) | |||
| @@ -1,3 +1,4 @@ | |||
| from collections.abc import Sequence | |||
| from enum import Enum | |||
| from pydantic import BaseModel, Field | |||
| @@ -32,7 +33,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk): | |||
| type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR | |||
| """generate route chunk type""" | |||
| value_selector: list[str] = Field(..., description="value selector") | |||
| value_selector: Sequence[str] = Field(..., description="value selector") | |||
| class TextGenerateRouteChunk(GenerateRouteChunk): | |||