### What problem does this PR solve? Feat: Handling abnormal anchor points of agent operators #3221 ### Type of change - [x] New Feature (non-breaking change which adds functionality)tags/v0.20.0
| import LLMLabel from '@/components/llm-select/llm-label'; | |||||
| import { IAgentNode } from '@/interfaces/database/flow'; | import { IAgentNode } from '@/interfaces/database/flow'; | ||||
| import { Handle, NodeProps, Position } from '@xyflow/react'; | import { Handle, NodeProps, Position } from '@xyflow/react'; | ||||
| import { get } from 'lodash'; | |||||
| import { memo, useMemo } from 'react'; | import { memo, useMemo } from 'react'; | ||||
| import { NodeHandleId } from '../../constant'; | |||||
| import { AgentExceptionMethod, NodeHandleId } from '../../constant'; | |||||
| import useGraphStore from '../../store'; | import useGraphStore from '../../store'; | ||||
| import { isBottomSubAgent } from '../../utils'; | import { isBottomSubAgent } from '../../utils'; | ||||
| import { CommonHandle } from './handle'; | import { CommonHandle } from './handle'; | ||||
| return !isBottomSubAgent(edges, id); | return !isBottomSubAgent(edges, id); | ||||
| }, [edges, id]); | }, [edges, id]); | ||||
| const exceptionMethod = useMemo(() => { | |||||
| return get(data, 'form.exception_method'); | |||||
| }, [data]); | |||||
| const isGotoMethod = useMemo(() => { | |||||
| return exceptionMethod === AgentExceptionMethod.Goto; | |||||
| }, [exceptionMethod]); | |||||
| return ( | return ( | ||||
| <ToolBar selected={selected} id={id} label={data.label}> | <ToolBar selected={selected} id={id} label={data.label}> | ||||
| <NodeWrapper selected={selected}> | <NodeWrapper selected={selected}> | ||||
| ></CommonHandle> | ></CommonHandle> | ||||
| </> | </> | ||||
| )} | )} | ||||
| <Handle | <Handle | ||||
| type="target" | type="target" | ||||
| position={Position.Top} | position={Position.Top} | ||||
| style={{ left: 20 }} | style={{ left: 20 }} | ||||
| ></Handle> | ></Handle> | ||||
| <NodeHeader id={id} name={data.name} label={data.label}></NodeHeader> | <NodeHeader id={id} name={data.name} label={data.label}></NodeHeader> | ||||
| <section className="flex flex-col gap-2"> | |||||
| <div className={'bg-background-card rounded-sm p-1'}> | |||||
| <LLMLabel value={get(data, 'form.llm_id')}></LLMLabel> | |||||
| </div> | |||||
| {(isGotoMethod || | |||||
| exceptionMethod === AgentExceptionMethod.Comment) && ( | |||||
| <div className="bg-background-card rounded-sm p-1 flex justify-between gap-2"> | |||||
| <span className="text-text-sub-title">Abnormal</span> | |||||
| <span className="truncate flex-1"> | |||||
| {isGotoMethod ? 'Exception branch' : 'Output default value'} | |||||
| </span> | |||||
| </div> | |||||
| )} | |||||
| </section> | |||||
| {isGotoMethod && ( | |||||
| <CommonHandle | |||||
| type="source" | |||||
| position={Position.Right} | |||||
| isConnectable={isConnectable} | |||||
| className="!bg-text-delete-red" | |||||
| style={{ ...RightHandleStyle, top: 94 }} | |||||
| nodeId={id} | |||||
| id={NodeHandleId.AgentException} | |||||
| isConnectableEnd={false} | |||||
| ></CommonHandle> | |||||
| )} | |||||
| </NodeWrapper> | </NodeWrapper> | ||||
| </ToolBar> | </ToolBar> | ||||
| ); | ); |
| max_rounds: 5, | max_rounds: 5, | ||||
| exception_method: null, | exception_method: null, | ||||
| exception_comment: '', | exception_comment: '', | ||||
| exception_goto: '', | |||||
| exception_goto: [], | |||||
| exception_default_value: '', | |||||
| tools: [], | tools: [], | ||||
| mcp: [], | mcp: [], | ||||
| outputs: { | outputs: { | ||||
| structured_output: { | |||||
| // topic: { | |||||
| // type: 'string', | |||||
| // description: | |||||
| // 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.', | |||||
| // enum: ['general', 'news'], | |||||
| // default: 'general', | |||||
| // }, | |||||
| }, | |||||
| // structured_output: { | |||||
| // topic: { | |||||
| // type: 'string', | |||||
| // description: | |||||
| // 'default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.', | |||||
| // enum: ['general', 'news'], | |||||
| // default: 'general', | |||||
| // }, | |||||
| // }, | |||||
| content: { | content: { | ||||
| type: 'string', | type: 'string', | ||||
| value: '', | value: '', | ||||
| Tool = 'tool', | Tool = 'tool', | ||||
| AgentTop = 'agentTop', | AgentTop = 'agentTop', | ||||
| AgentBottom = 'agentBottom', | AgentBottom = 'agentBottom', | ||||
| AgentException = 'agentException', | |||||
| } | } | ||||
| export enum VariableType { | export enum VariableType { |
| return buildCategorizeToOptions; | return buildCategorizeToOptions; | ||||
| }; | }; | ||||
| /** | |||||
| * dumped | |||||
| * @param nodeId | |||||
| * @returns | |||||
| */ | |||||
| export const useHandleFormSelectChange = (nodeId?: string) => { | |||||
| const { addEdge, deleteEdgeBySourceAndSourceHandle } = useGraphStore( | |||||
| (state) => state, | |||||
| ); | |||||
| const handleSelectChange = useCallback( | |||||
| (name?: string) => (value?: string) => { | |||||
| if (nodeId && name) { | |||||
| if (value) { | |||||
| addEdge({ | |||||
| source: nodeId, | |||||
| target: value, | |||||
| sourceHandle: name, | |||||
| targetHandle: null, | |||||
| }); | |||||
| } else { | |||||
| // clear selected value | |||||
| deleteEdgeBySourceAndSourceHandle({ | |||||
| source: nodeId, | |||||
| sourceHandle: name, | |||||
| }); | |||||
| } | |||||
| } | |||||
| }, | |||||
| [addEdge, nodeId, deleteEdgeBySourceAndSourceHandle], | |||||
| ); | |||||
| return { handleSelectChange }; | |||||
| }; | |||||
| export const useBuildSortOptions = () => { | export const useBuildSortOptions = () => { | ||||
| const { t } = useTranslate('flow'); | const { t } = useTranslate('flow'); | ||||
| import { useFindLlmByUuid } from '@/hooks/use-llm-request'; | import { useFindLlmByUuid } from '@/hooks/use-llm-request'; | ||||
| import { buildOptions } from '@/utils/form'; | import { buildOptions } from '@/utils/form'; | ||||
| import { zodResolver } from '@hookform/resolvers/zod'; | import { zodResolver } from '@hookform/resolvers/zod'; | ||||
| import { memo, useMemo } from 'react'; | |||||
| import { memo, useEffect, useMemo } from 'react'; | |||||
| import { useForm, useWatch } from 'react-hook-form'; | import { useForm, useWatch } from 'react-hook-form'; | ||||
| import { useTranslation } from 'react-i18next'; | import { useTranslation } from 'react-i18next'; | ||||
| import { z } from 'zod'; | import { z } from 'zod'; | ||||
| import { | import { | ||||
| AgentExceptionMethod, | AgentExceptionMethod, | ||||
| NodeHandleId, | |||||
| VariableType, | VariableType, | ||||
| initialAgentValues, | initialAgentValues, | ||||
| } from '../../constant'; | } from '../../constant'; | ||||
| import { INextOperatorForm } from '../../interface'; | import { INextOperatorForm } from '../../interface'; | ||||
| import useGraphStore from '../../store'; | import useGraphStore from '../../store'; | ||||
| import { isBottomSubAgent } from '../../utils'; | import { isBottomSubAgent } from '../../utils'; | ||||
| import { buildOutputList } from '../../utils/build-output-list'; | |||||
| import { DescriptionField } from '../components/description-field'; | import { DescriptionField } from '../components/description-field'; | ||||
| import { FormWrapper } from '../components/form-wrapper'; | |||||
| import { Output } from '../components/output'; | import { Output } from '../components/output'; | ||||
| import { PromptEditor } from '../components/prompt-editor'; | import { PromptEditor } from '../components/prompt-editor'; | ||||
| import { QueryVariable } from '../components/query-variable'; | import { QueryVariable } from '../components/query-variable'; | ||||
| max_rounds: z.coerce.number().optional(), | max_rounds: z.coerce.number().optional(), | ||||
| exception_method: z.string().nullable(), | exception_method: z.string().nullable(), | ||||
| exception_comment: z.string().optional(), | exception_comment: z.string().optional(), | ||||
| exception_goto: z.string().optional(), | |||||
| exception_goto: z.array(z.string()).optional(), | |||||
| exception_default_value: z.string().optional(), | |||||
| ...LargeModelFilterFormSchema, | ...LargeModelFilterFormSchema, | ||||
| }); | }); | ||||
| const outputList = buildOutputList(initialAgentValues.outputs); | |||||
| function AgentForm({ node }: INextOperatorForm) { | function AgentForm({ node }: INextOperatorForm) { | ||||
| const { t } = useTranslation(); | const { t } = useTranslation(); | ||||
| const { edges } = useGraphStore((state) => state); | |||||
| const { edges, deleteEdgesBySourceAndSourceHandle } = useGraphStore( | |||||
| (state) => state, | |||||
| ); | |||||
| const defaultValues = useValues(node); | const defaultValues = useValues(node); | ||||
| return isBottomSubAgent(edges, node?.id); | return isBottomSubAgent(edges, node?.id); | ||||
| }, [edges, node?.id]); | }, [edges, node?.id]); | ||||
| const outputList = useMemo(() => { | |||||
| return [ | |||||
| { title: 'content', type: initialAgentValues.outputs.content.type }, | |||||
| ]; | |||||
| }, []); | |||||
| const form = useForm<z.infer<typeof FormSchema>>({ | const form = useForm<z.infer<typeof FormSchema>>({ | ||||
| defaultValues: defaultValues, | defaultValues: defaultValues, | ||||
| resolver: zodResolver(FormSchema), | resolver: zodResolver(FormSchema), | ||||
| const findLlmByUuid = useFindLlmByUuid(); | const findLlmByUuid = useFindLlmByUuid(); | ||||
| const exceptionMethod = useWatch({ | |||||
| control: form.control, | |||||
| name: 'exception_method', | |||||
| }); | |||||
| useEffect(() => { | |||||
| if (exceptionMethod !== AgentExceptionMethod.Goto) { | |||||
| if (node?.id) { | |||||
| deleteEdgesBySourceAndSourceHandle( | |||||
| node?.id, | |||||
| NodeHandleId.AgentException, | |||||
| ); | |||||
| } | |||||
| } | |||||
| }, [deleteEdgesBySourceAndSourceHandle, exceptionMethod, node?.id]); | |||||
| useWatchFormChange(node?.id, form); | useWatchFormChange(node?.id, form); | ||||
| return ( | return ( | ||||
| <Form {...form}> | <Form {...form}> | ||||
| <form | |||||
| className="space-y-6 p-4" | |||||
| onSubmit={(e) => { | |||||
| e.preventDefault(); | |||||
| }} | |||||
| > | |||||
| <FormWrapper> | |||||
| <FormContainer> | <FormContainer> | ||||
| {isSubAgent && <DescriptionField></DescriptionField>} | {isSubAgent && <DescriptionField></DescriptionField>} | ||||
| <LargeModelFormField></LargeModelFormField> | <LargeModelFormField></LargeModelFormField> | ||||
| </FormItem> | </FormItem> | ||||
| )} | )} | ||||
| /> | /> | ||||
| <FormField | |||||
| control={form.control} | |||||
| name={`exception_default_value`} | |||||
| render={({ field }) => ( | |||||
| <FormItem className="flex-1"> | |||||
| <FormLabel>Exception default value</FormLabel> | |||||
| <FormControl> | |||||
| <Input {...field} /> | |||||
| </FormControl> | |||||
| </FormItem> | |||||
| )} | |||||
| /> | |||||
| <FormField | <FormField | ||||
| control={form.control} | control={form.control} | ||||
| name={`exception_comment`} | name={`exception_comment`} | ||||
| </FormItem> | </FormItem> | ||||
| )} | )} | ||||
| /> | /> | ||||
| <QueryVariable | |||||
| name="exception_goto" | |||||
| label="Exception goto" | |||||
| type={VariableType.File} | |||||
| ></QueryVariable> | |||||
| </FormContainer> | </FormContainer> | ||||
| </Collapse> | </Collapse> | ||||
| <Output list={outputList}></Output> | <Output list={outputList}></Output> | ||||
| </form> | |||||
| </FormWrapper> | |||||
| </Form> | </Form> | ||||
| ); | ); | ||||
| } | } |
| const FormSchema = useCreateCategorizeFormSchema(); | const FormSchema = useCreateCategorizeFormSchema(); | ||||
| const deleteCategorizeCaseEdges = useGraphStore( | const deleteCategorizeCaseEdges = useGraphStore( | ||||
| (state) => state.deleteCategorizeCaseEdges, | |||||
| (state) => state.deleteEdgesBySourceAndSourceHandle, | |||||
| ); | ); | ||||
| const form = useFormContext<z.infer<typeof FormSchema>>(); | const form = useFormContext<z.infer<typeof FormSchema>>(); | ||||
| const { t } = useTranslate('flow'); | const { t } = useTranslate('flow'); |
| FormLabel, | FormLabel, | ||||
| FormMessage, | FormMessage, | ||||
| } from '@/components/ui/form'; | } from '@/components/ui/form'; | ||||
| import { Input } from '@/components/ui/input'; | |||||
| import { RAGFlowSelect } from '@/components/ui/select'; | import { RAGFlowSelect } from '@/components/ui/select'; | ||||
| import { buildOptions } from '@/utils/form'; | import { buildOptions } from '@/utils/form'; | ||||
| import { zodResolver } from '@hookform/resolvers/zod'; | import { zodResolver } from '@hookform/resolvers/zod'; | ||||
| import { ApiKeyField } from '../components/api-key-field'; | import { ApiKeyField } from '../components/api-key-field'; | ||||
| import { FormWrapper } from '../components/form-wrapper'; | import { FormWrapper } from '../components/form-wrapper'; | ||||
| import { Output } from '../components/output'; | import { Output } from '../components/output'; | ||||
| import { PromptEditor } from '../components/prompt-editor'; | |||||
| import { TavilyFormSchema } from '../tavily-form'; | import { TavilyFormSchema } from '../tavily-form'; | ||||
| const outputList = buildOutputList(initialTavilyExtractValues.outputs); | const outputList = buildOutputList(initialTavilyExtractValues.outputs); | ||||
| <FormItem> | <FormItem> | ||||
| <FormLabel>URL</FormLabel> | <FormLabel>URL</FormLabel> | ||||
| <FormControl> | <FormControl> | ||||
| <Input {...field} /> | |||||
| <PromptEditor | |||||
| {...field} | |||||
| multiLine={false} | |||||
| showToolbar={false} | |||||
| ></PromptEditor> | |||||
| </FormControl> | </FormControl> | ||||
| <FormMessage /> | <FormMessage /> | ||||
| </FormItem> | </FormItem> |
| import { Switch } from '@/components/ui/switch'; | import { Switch } from '@/components/ui/switch'; | ||||
| import { buildOptions } from '@/utils/form'; | import { buildOptions } from '@/utils/form'; | ||||
| import { zodResolver } from '@hookform/resolvers/zod'; | import { zodResolver } from '@hookform/resolvers/zod'; | ||||
| import { memo, useMemo } from 'react'; | |||||
| import { memo } from 'react'; | |||||
| import { useForm } from 'react-hook-form'; | import { useForm } from 'react-hook-form'; | ||||
| import { z } from 'zod'; | import { z } from 'zod'; | ||||
| import { | import { | ||||
| initialTavilyValues, | initialTavilyValues, | ||||
| } from '../../constant'; | } from '../../constant'; | ||||
| import { INextOperatorForm } from '../../interface'; | import { INextOperatorForm } from '../../interface'; | ||||
| import { buildOutputList } from '../../utils/build-output-list'; | |||||
| import { ApiKeyField } from '../components/api-key-field'; | import { ApiKeyField } from '../components/api-key-field'; | ||||
| import { FormWrapper } from '../components/form-wrapper'; | import { FormWrapper } from '../components/form-wrapper'; | ||||
| import { Output, OutputType } from '../components/output'; | |||||
| import { Output } from '../components/output'; | |||||
| import { QueryVariable } from '../components/query-variable'; | import { QueryVariable } from '../components/query-variable'; | ||||
| import { DynamicDomain } from './dynamic-domain'; | import { DynamicDomain } from './dynamic-domain'; | ||||
| import { useValues } from './use-values'; | import { useValues } from './use-values'; | ||||
| api_key: z.string(), | api_key: z.string(), | ||||
| }; | }; | ||||
| const outputList = buildOutputList(initialTavilyValues.outputs); | |||||
| function TavilyForm({ node }: INextOperatorForm) { | function TavilyForm({ node }: INextOperatorForm) { | ||||
| const values = useValues(node); | const values = useValues(node); | ||||
| resolver: zodResolver(FormSchema), | resolver: zodResolver(FormSchema), | ||||
| }); | }); | ||||
| const outputList = useMemo(() => { | |||||
| return Object.entries(initialTavilyValues.outputs).reduce<OutputType[]>( | |||||
| (pre, [key, val]) => { | |||||
| pre.push({ title: key, type: val.type }); | |||||
| return pre; | |||||
| }, | |||||
| [], | |||||
| ); | |||||
| }, []); | |||||
| useWatchFormChange(node?.id, form); | useWatchFormChange(node?.id, form); | ||||
| return ( | return ( |
| deleteAgentDownstreamNodesById: (id: string) => void; | deleteAgentDownstreamNodesById: (id: string) => void; | ||||
| deleteAgentToolNodeById: (id: string) => void; | deleteAgentToolNodeById: (id: string) => void; | ||||
| deleteIterationNodeById: (id: string) => void; | deleteIterationNodeById: (id: string) => void; | ||||
| deleteEdgeBySourceAndSourceHandle: (connection: Partial<Connection>) => void; | |||||
| findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined; | findNodeByName: (operatorName: Operator) => RAGFlowNodeType | undefined; | ||||
| updateMutableNodeFormItem: (id: string, field: string, value: any) => void; | updateMutableNodeFormItem: (id: string, field: string, value: any) => void; | ||||
| getOperatorTypeFromId: (id?: string | null) => string | undefined; | getOperatorTypeFromId: (id?: string | null) => string | undefined; | ||||
| setClickedNodeId: (id?: string) => void; | setClickedNodeId: (id?: string) => void; | ||||
| setClickedToolId: (id?: string) => void; | setClickedToolId: (id?: string) => void; | ||||
| findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined; | findUpstreamNodeById: (id?: string | null) => RAGFlowNodeType | undefined; | ||||
| deleteCategorizeCaseEdges: (source: string, sourceHandle: string) => void; // Deleting a condition of a classification operator will delete the related edge | |||||
| deleteEdgesBySourceAndSourceHandle: ( | |||||
| source: string, | |||||
| sourceHandle: string, | |||||
| ) => void; // Deleting a condition of a classification operator will delete the related edge | |||||
| findAgentToolNodeById: (id: string | null) => string | undefined; | findAgentToolNodeById: (id: string | null) => string | undefined; | ||||
| selectNodeIds: (nodeIds: string[]) => void; | selectNodeIds: (nodeIds: string[]) => void; | ||||
| }; | }; | ||||
| edges: edges.filter((edge) => edge.id !== id), | edges: edges.filter((edge) => edge.id !== id), | ||||
| }); | }); | ||||
| }, | }, | ||||
| deleteEdgeBySourceAndSourceHandle: ({ | |||||
| source, | |||||
| sourceHandle, | |||||
| }: Partial<Connection>) => { | |||||
| const { edges } = get(); | |||||
| const nextEdges = edges.filter( | |||||
| (edge) => | |||||
| edge.source !== source || edge.sourceHandle !== sourceHandle, | |||||
| ); | |||||
| set({ | |||||
| edges: nextEdges, | |||||
| }); | |||||
| }, | |||||
| deleteNodeById: (id: string) => { | deleteNodeById: (id: string) => { | ||||
| const { | const { | ||||
| nodes, | nodes, | ||||
| const edge = edges.find((x) => x.target === id); | const edge = edges.find((x) => x.target === id); | ||||
| return getNode(edge?.source); | return getNode(edge?.source); | ||||
| }, | }, | ||||
| deleteCategorizeCaseEdges: (source, sourceHandle) => { | |||||
| deleteEdgesBySourceAndSourceHandle: (source, sourceHandle) => { | |||||
| const { edges, setEdges } = get(); | const { edges, setEdges } = get(); | ||||
| setEdges( | setEdges( | ||||
| edges.filter( | edges.filter( |
| import fs from 'fs'; | |||||
| import path from 'path'; | |||||
| import customer_service from '../../../../graph/test/dsl_examples/customer_service.json'; | |||||
| import headhunter_zh from '../../../../graph/test/dsl_examples/headhunter_zh.json'; | |||||
| import interpreter from '../../../../graph/test/dsl_examples/interpreter.json'; | |||||
| import retrievalRelevantRewriteAndGenerate from '../../../../graph/test/dsl_examples/retrieval_relevant_rewrite_and_generate.json'; | |||||
| import { dsl } from './mock'; | |||||
| import { buildNodesAndEdgesFromDSLComponents } from './utils'; | |||||
| test('buildNodesAndEdgesFromDSLComponents', () => { | |||||
| const { edges, nodes } = buildNodesAndEdgesFromDSLComponents(dsl.components); | |||||
| expect(nodes.length).toEqual(4); | |||||
| expect(edges.length).toEqual(4); | |||||
| expect(edges).toEqual( | |||||
| expect.arrayContaining([ | |||||
| expect.objectContaining({ | |||||
| source: 'begin', | |||||
| target: 'Answer:China', | |||||
| }), | |||||
| expect.objectContaining({ | |||||
| source: 'Answer:China', | |||||
| target: 'Retrieval:China', | |||||
| }), | |||||
| expect.objectContaining({ | |||||
| source: 'Retrieval:China', | |||||
| target: 'Generate:China', | |||||
| }), | |||||
| expect.objectContaining({ | |||||
| source: 'Generate:China', | |||||
| target: 'Answer:China', | |||||
| }), | |||||
| ]), | |||||
| ); | |||||
| }); | |||||
| test('build nodes and edges from headhunter_zh dsl', () => { | |||||
| const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( | |||||
| headhunter_zh.components, | |||||
| ); | |||||
| console.info('node length', nodes.length); | |||||
| console.info('edge length', edges.length); | |||||
| try { | |||||
| fs.writeFileSync( | |||||
| path.join(__dirname, 'headhunter_zh.json'), | |||||
| JSON.stringify({ edges, nodes }, null, 4), | |||||
| ); | |||||
| console.log('JSON data is saved.'); | |||||
| } catch (error) { | |||||
| console.warn(error); | |||||
| } | |||||
| expect(nodes.length).toEqual(12); | |||||
| }); | |||||
| test('build nodes and edges from customer_service dsl', () => { | |||||
| const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( | |||||
| customer_service.components, | |||||
| ); | |||||
| console.info('node length', nodes.length); | |||||
| console.info('edge length', edges.length); | |||||
| try { | |||||
| fs.writeFileSync( | |||||
| path.join(__dirname, 'customer_service.json'), | |||||
| JSON.stringify({ edges, nodes }, null, 4), | |||||
| ); | |||||
| console.log('JSON data is saved.'); | |||||
| } catch (error) { | |||||
| console.warn(error); | |||||
| } | |||||
| expect(nodes.length).toEqual(12); | |||||
| }); | |||||
| test('build nodes and edges from interpreter dsl', () => { | |||||
| const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( | |||||
| interpreter.components, | |||||
| ); | |||||
| console.info('node length', nodes.length); | |||||
| console.info('edge length', edges.length); | |||||
| try { | |||||
| fs.writeFileSync( | |||||
| path.join(__dirname, 'interpreter.json'), | |||||
| JSON.stringify({ edges, nodes }, null, 4), | |||||
| ); | |||||
| console.log('JSON data is saved.'); | |||||
| } catch (error) { | |||||
| console.warn(error); | |||||
| } | |||||
| expect(nodes.length).toEqual(12); | |||||
| }); | |||||
| test('build nodes and edges from chat bot dsl', () => { | |||||
| const { edges, nodes } = buildNodesAndEdgesFromDSLComponents( | |||||
| retrievalRelevantRewriteAndGenerate.components, | |||||
| ); | |||||
| try { | |||||
| fs.writeFileSync( | |||||
| path.join(__dirname, 'retrieval_relevant_rewrite_and_generate.json'), | |||||
| JSON.stringify({ edges, nodes }, null, 4), | |||||
| ); | |||||
| console.log('JSON data is saved.'); | |||||
| } catch (error) { | |||||
| console.warn(error); | |||||
| } | |||||
| expect(nodes.length).toEqual(12); | |||||
| }); |
| } from '@/interfaces/database/agent'; | } from '@/interfaces/database/agent'; | ||||
| import { DSLComponents, RAGFlowNodeType } from '@/interfaces/database/flow'; | import { DSLComponents, RAGFlowNodeType } from '@/interfaces/database/flow'; | ||||
| import { removeUselessFieldsFromValues } from '@/utils/form'; | import { removeUselessFieldsFromValues } from '@/utils/form'; | ||||
| import { Edge, Node, Position, XYPosition } from '@xyflow/react'; | |||||
| import { Edge, Node, XYPosition } from '@xyflow/react'; | |||||
| import { FormInstance, FormListFieldData } from 'antd'; | import { FormInstance, FormListFieldData } from 'antd'; | ||||
| import { humanId } from 'human-id'; | import { humanId } from 'human-id'; | ||||
| import { curry, get, intersectionWith, isEqual, omit, sample } from 'lodash'; | import { curry, get, intersectionWith, isEqual, omit, sample } from 'lodash'; | ||||
| import pipe from 'lodash/fp/pipe'; | import pipe from 'lodash/fp/pipe'; | ||||
| import isObject from 'lodash/isObject'; | import isObject from 'lodash/isObject'; | ||||
| import { v4 as uuidv4 } from 'uuid'; | |||||
| import { | import { | ||||
| CategorizeAnchorPointPositions, | CategorizeAnchorPointPositions, | ||||
| NoDebugOperatorsList, | NoDebugOperatorsList, | ||||
| NodeHandleId, | NodeHandleId, | ||||
| NodeMap, | |||||
| Operator, | Operator, | ||||
| } from './constant'; | } from './constant'; | ||||
| import { BeginQuery, IPosition } from './interface'; | import { BeginQuery, IPosition } from './interface'; | ||||
| const buildEdges = ( | |||||
| operatorIds: string[], | |||||
| currentId: string, | |||||
| allEdges: Edge[], | |||||
| isUpstream = false, | |||||
| componentName: string, | |||||
| nodeParams: Record<string, unknown>, | |||||
| ) => { | |||||
| operatorIds.forEach((cur) => { | |||||
| const source = isUpstream ? cur : currentId; | |||||
| const target = isUpstream ? currentId : cur; | |||||
| if (!allEdges.some((e) => e.source === source && e.target === target)) { | |||||
| const edge: Edge = { | |||||
| id: uuidv4(), | |||||
| label: '', | |||||
| // type: 'step', | |||||
| source: source, | |||||
| target: target, | |||||
| // markerEnd: { | |||||
| // type: MarkerType.ArrowClosed, | |||||
| // color: 'rgb(157 149 225)', | |||||
| // width: 20, | |||||
| // height: 20, | |||||
| // }, | |||||
| }; | |||||
| if (componentName === Operator.Categorize && !isUpstream) { | |||||
| const categoryDescription = | |||||
| nodeParams.category_description as ICategorizeItemResult; | |||||
| const name = Object.keys(categoryDescription).find( | |||||
| (x) => categoryDescription[x].to === target, | |||||
| ); | |||||
| if (name) { | |||||
| edge.sourceHandle = name; | |||||
| } | |||||
| } | |||||
| allEdges.push(edge); | |||||
| } | |||||
| }); | |||||
| }; | |||||
| export const buildNodesAndEdgesFromDSLComponents = (data: DSLComponents) => { | |||||
| const nodes: Node[] = []; | |||||
| let edges: Edge[] = []; | |||||
| Object.entries(data).forEach(([key, value]) => { | |||||
| const downstream = [...value.downstream]; | |||||
| const upstream = [...value.upstream]; | |||||
| const { component_name: componentName, params } = value.obj; | |||||
| nodes.push({ | |||||
| id: key, | |||||
| type: NodeMap[value.obj.component_name as Operator] || 'ragNode', | |||||
| position: { x: 0, y: 0 }, | |||||
| data: { | |||||
| label: componentName, | |||||
| name: humanId(), | |||||
| form: params, | |||||
| }, | |||||
| sourcePosition: Position.Left, | |||||
| targetPosition: Position.Right, | |||||
| }); | |||||
| buildEdges(upstream, key, edges, true, componentName, params); | |||||
| buildEdges(downstream, key, edges, false, componentName, params); | |||||
| }); | |||||
| function buildAgentExceptionGoto(edges: Edge[], nodeId: string) { | |||||
| const exceptionEdges = edges.filter( | |||||
| (x) => | |||||
| x.source === nodeId && x.sourceHandle === NodeHandleId.AgentException, | |||||
| ); | |||||
| return { nodes, edges }; | |||||
| }; | |||||
| return exceptionEdges.map((x) => x.target); | |||||
| } | |||||
| const buildComponentDownstreamOrUpstream = ( | const buildComponentDownstreamOrUpstream = ( | ||||
| edges: Edge[], | edges: Edge[], | ||||
| const node = nodes.find((x) => x.id === nodeId); | const node = nodes.find((x) => x.id === nodeId); | ||||
| let isNotUpstreamTool = true; | let isNotUpstreamTool = true; | ||||
| let isNotUpstreamAgent = true; | let isNotUpstreamAgent = true; | ||||
| let isNotExceptionGoto = true; | |||||
| if (isBuildDownstream && node?.data.label === Operator.Agent) { | if (isBuildDownstream && node?.data.label === Operator.Agent) { | ||||
| isNotExceptionGoto = y.sourceHandle !== NodeHandleId.AgentException; | |||||
| // Exclude the tool operator downstream of the agent operator | // Exclude the tool operator downstream of the agent operator | ||||
| isNotUpstreamTool = !y.target.startsWith(Operator.Tool); | isNotUpstreamTool = !y.target.startsWith(Operator.Tool); | ||||
| // Exclude the agent operator downstream of the agent operator | // Exclude the agent operator downstream of the agent operator | ||||
| return ( | return ( | ||||
| y[isBuildDownstream ? 'source' : 'target'] === nodeId && | y[isBuildDownstream ? 'source' : 'target'] === nodeId && | ||||
| isNotUpstreamTool && | isNotUpstreamTool && | ||||
| isNotUpstreamAgent | |||||
| isNotUpstreamAgent && | |||||
| isNotExceptionGoto | |||||
| ); | ); | ||||
| }) | }) | ||||
| .map((y) => y[isBuildDownstream ? 'target' : 'source']); | .map((y) => y[isBuildDownstream ? 'target' : 'source']); | ||||
| switch (operatorName) { | switch (operatorName) { | ||||
| case Operator.Agent: { | case Operator.Agent: { | ||||
| const { params: formData } = buildAgentTools(edges, nodes, id); | const { params: formData } = buildAgentTools(edges, nodes, id); | ||||
| params = formData; | |||||
| params = { | |||||
| ...formData, | |||||
| exception_goto: buildAgentExceptionGoto(edges, id), | |||||
| }; | |||||
| break; | break; | ||||
| } | } | ||||
| case Operator.Categorize: | case Operator.Categorize: | ||||
| if (cur?.name) { | if (cur?.name) { | ||||
| pre[cur.name] = { | pre[cur.name] = { | ||||
| ...omit(cur, 'name', 'examples'), | ...omit(cur, 'name', 'examples'), | ||||
| examples: convertToStringArray(cur.examples), | |||||
| examples: convertToStringArray(cur.examples) as string[], | |||||
| }; | }; | ||||
| } | } | ||||
| return pre; | return pre; |