瀏覽代碼

Feat: Connect conditional operators to other operators #3221 (#8231)

### What problem does this PR solve?

Feat: Connect conditional operators to other operators #3221

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
tags/v0.19.1
balibabu 4 月之前
父節點
當前提交
a9d9215547
沒有連結到貢獻者的電子郵件帳戶。

+ 18
- 0
web/src/interfaces/database/agent.ts 查看文件

@@ -9,3 +9,21 @@ export type ICategorizeItemResult = Record<
string,
Omit<ICategorizeItem, 'name' | 'examples'> & { examples: string[] }
>;

export interface ISwitchCondition {
items: ISwitchItem[];
logical_operator: string;
to: string[];
}

export interface ISwitchItem {
cpn_id: string;
operator: string;
value: string;
}

export interface ISwitchForm {
conditions: ISwitchCondition[];
end_cpn_ids: string[];
no: string;
}

+ 1
- 1
web/src/interfaces/database/flow.ts 查看文件

@@ -92,7 +92,7 @@ export interface IRelevantForm extends IGenerateForm {
export interface ISwitchCondition {
items: ISwitchItem[];
logical_operator: string;
to: string;
to: string[] | string;
}

export interface ISwitchItem {

+ 0
- 3
web/src/pages/agent/canvas/index.tsx 查看文件

@@ -17,7 +17,6 @@ import {
useHandleDrop,
useSelectCanvasData,
useValidateConnection,
useWatchNodeFormDataChange,
} from '../hooks';
import { useAddNode } from '../hooks/use-add-node';
import { useBeforeDelete } from '../hooks/use-before-delete';
@@ -120,8 +119,6 @@ function AgentCanvas({ drawerVisible, hideDrawer }: IProps) {

const { handleBeforeDelete } = useBeforeDelete();

useWatchNodeFormDataChange();

const { addCanvasNode } = useAddNode(reactFlowInstance);

useEffect(() => {

+ 30
- 25
web/src/pages/agent/canvas/node/switch-node.tsx 查看文件

@@ -1,9 +1,12 @@
import { IconFont } from '@/components/icon-font';
import { useTheme } from '@/components/theme-provider';
import { Card, CardContent } from '@/components/ui/card';
import { ISwitchCondition, ISwitchNode } from '@/interfaces/database/flow';
import { Handle, NodeProps, Position } from '@xyflow/react';
import { Divider, Flex } from 'antd';
import { Flex } from 'antd';
import classNames from 'classnames';
import { memo } from 'react';
import { memo, useCallback } from 'react';
import { SwitchOperatorOptions } from '../../constant';
import { useGetComponentLabelByValue } from '../../hooks/use-get-begin-query';
import { RightHandleStyle } from './handle-icon';
import { useBuildSwitchHandlePositions } from './hooks';
@@ -29,29 +32,28 @@ const ConditionBlock = ({
}) => {
const items = condition?.items ?? [];
const getLabel = useGetComponentLabelByValue(nodeId);

const renderOperatorIcon = useCallback((operator?: string) => {
const name = SwitchOperatorOptions.find((x) => x.value === operator)?.icon;
return <IconFont name={name!}></IconFont>;
}, []);

return (
<Flex vertical className={styles.conditionBlock}>
{items.map((x, idx) => (
<div key={idx}>
<Flex>
<div
className={classNames(styles.conditionLine, styles.conditionKey)}
>
{getLabel(x?.cpn_id)}
</div>
<span className={styles.conditionOperator}>{x?.operator}</span>
<Flex flex={1} className={styles.conditionLine}>
{x?.value}
</Flex>
</Flex>
{idx + 1 < items.length && (
<Divider orientationMargin="0" className={styles.zeroDivider}>
{condition?.logical_operator}
</Divider>
)}
</div>
))}
</Flex>
<Card>
<CardContent className="space-y-1 p-1">
{items.map((x, idx) => (
<div key={idx}>
<section className="flex justify-between gap-2 items-center text-xs">
<div className="flex-1 truncate text-background-checked">
{getLabel(x?.cpn_id)}
</div>
<span>{renderOperatorIcon(x?.operator)}</span>
<div className="flex-1 truncate">{x?.value}</div>
</section>
</div>
))}
</CardContent>
</Card>
);
};

@@ -87,7 +89,10 @@ function InnerSwitchNode({ id, data, selected }: NodeProps<ISwitchNode>) {
<div key={idx}>
<Flex vertical>
<Flex justify={'space-between'}>
<span>{idx < positions.length - 1 && position.text}</span>
<span className="text-text-sub-title text-xs translate-y-2">
{idx < positions.length - 1 &&
position.condition?.logical_operator?.toUpperCase()}
</span>
<span>{getConditionKey(idx, positions.length)}</span>
</Flex>
{position.condition && (

+ 33
- 20
web/src/pages/agent/constant.tsx 查看文件

@@ -129,6 +129,8 @@ export enum Operator {
Agent = 'Agent',
}

export const SwitchLogicOperatorOptions = ['and', 'or'];

export const CommonOperatorList = Object.values(Operator).filter(
(x) => x !== Operator.Note,
);
@@ -445,6 +447,23 @@ export const componentMenuList = [
},
];

export const SwitchOperatorOptions = [
{ value: '=', label: 'equal', icon: 'equal' },
{ value: '≠', label: 'notEqual', icon: 'not-equals' },
{ value: '>', label: 'gt', icon: 'Less' },
{ value: '≥', label: 'ge', icon: 'Greater-or-equal' },
{ value: '<', label: 'lt', icon: 'Less' },
{ value: '≤', label: 'le', icon: 'less-or-equal' },
{ value: 'contains', label: 'contains', icon: 'Contains' },
{ value: 'not contains', label: 'notContains', icon: 'not-contains' },
{ value: 'start with', label: 'startWith', icon: 'list-start' },
{ value: 'end with', label: 'endWith', icon: 'list-end' },
{ value: 'empty', label: 'empty', icon: 'circle' },
{ value: 'not empty', label: 'notEmpty', icon: 'circle-slash-2' },
];

export const SwitchElseTo = 'end_cpn_ids';

const initialQueryBaseValues = {
query: [],
};
@@ -616,7 +635,20 @@ export const initialExeSqlValues = {
...initialQueryBaseValues,
};

export const initialSwitchValues = { conditions: [] };
export const initialSwitchValues = {
conditions: [
{
logical_operator: SwitchLogicOperatorOptions[0],
items: [
{
operator: SwitchOperatorOptions[0].value,
},
],
to: [],
},
],
[SwitchElseTo]: [],
};

export const initialWenCaiValues = {
top_n: 20,
@@ -3000,25 +3032,6 @@ export const ExeSQLOptions = ['mysql', 'postgresql', 'mariadb', 'mssql'].map(
}),
);

export const SwitchElseTo = 'end_cpn_id';

export const SwitchOperatorOptions = [
{ value: '=', label: 'equal', icon: 'equal' },
{ value: '≠', label: 'notEqual', icon: 'not-equals' },
{ value: '>', label: 'gt', icon: 'Less' },
{ value: '≥', label: 'ge', icon: 'Greater-or-equal' },
{ value: '<', label: 'lt', icon: 'Less' },
{ value: '≤', label: 'le', icon: 'less-or-equal' },
{ value: 'contains', label: 'contains', icon: 'Contains' },
{ value: 'not contains', label: 'notContains', icon: 'not-contains' },
{ value: 'start with', label: 'startWith', icon: 'list-start' },
{ value: 'end with', label: 'endWith', icon: 'list-end' },
// { value: 'empty', label: 'empty', icon: '' },
// { value: 'not empty', label: 'notEmpty', icon: '' },
];

export const SwitchLogicOperatorOptions = ['and', 'or'];

export const WenCaiQueryTypeOptions = [
'stock',
'zhishu',

+ 22
- 12
web/src/pages/agent/form/switch-form/index.tsx 查看文件

@@ -12,7 +12,7 @@ import {
import { RAGFlowSelect } from '@/components/ui/select';
import { Separator } from '@/components/ui/separator';
import { Textarea } from '@/components/ui/textarea';
import { ISwitchForm } from '@/interfaces/database/flow';
import { ISwitchForm } from '@/interfaces/database/agent';
import { cn } from '@/lib/utils';
import { zodResolver } from '@hookform/resolvers/zod';
import { X } from 'lucide-react';
@@ -27,6 +27,7 @@ import {
} from '../../constant';
import { useBuildFormSelectOptions } from '../../form-hooks';
import { useBuildComponentIdAndBeginOptions } from '../../hooks/use-get-begin-query';
import { useWatchFormChange } from '../../hooks/use-watch-form-change';
import { IOperatorForm } from '../../interface';
import { useValues } from './use-values';

@@ -40,20 +41,27 @@ type ConditionCardsProps = {
parentLength: number;
} & IOperatorForm;

const OperatorIcon = function OperatorIcon({
icon,
value,
}: Omit<(typeof SwitchOperatorOptions)[0], 'label'>) {
return (
<IconFont
name={icon}
className={cn('size-4', {
'rotate-180': value === '>',
})}
></IconFont>
);
};

function useBuildSwitchOperatorOptions() {
const { t } = useTranslation();

const switchOperatorOptions = useMemo(() => {
return SwitchOperatorOptions.map((x) => ({
value: x.value,
icon: (
<IconFont
name={x.icon}
className={cn('size-4', {
'rotate-180': x.value === '>',
})}
></IconFont>
),
icon: <OperatorIcon icon={x.icon} value={x.value}></OperatorIcon>,
label: t(`flow.switchOperatorOptions.${x.label}`),
}));
}, [t]);
@@ -174,7 +182,7 @@ function ConditionCards({
className="mt-6"
onClick={() => append({ operator: switchOperatorOptions[0].value })}
>
add
Add
</BlockButton>
</div>
</section>
@@ -183,7 +191,7 @@ function ConditionCards({

const SwitchForm = ({ node }: IOperatorForm) => {
const { t } = useTranslation();
const values = useValues();
const values = useValues(node);
const switchOperatorOptions = useBuildSwitchOperatorOptions();

const FormSchema = z.object({
@@ -234,6 +242,8 @@ const SwitchForm = ({ node }: IOperatorForm) => {
}));
}, [t]);

useWatchFormChange(node?.id, form);

return (
<Form {...form}>
<form
@@ -289,7 +299,7 @@ const SwitchForm = ({ node }: IOperatorForm) => {
})
}
>
add
Add
</BlockButton>
</form>
</Form>

+ 2
- 5
web/src/pages/agent/form/switch-form/use-values.ts 查看文件

@@ -1,16 +1,13 @@
import { RAGFlowNodeType } from '@/interfaces/database/flow';
import { isEmpty } from 'lodash';
import { useMemo } from 'react';

const defaultValues = {
conditions: [],
};
import { initialSwitchValues } from '../../constant';

export function useValues(node?: RAGFlowNodeType) {
const values = useMemo(() => {
const formData = node?.data?.form;
if (isEmpty(formData)) {
return defaultValues;
return initialSwitchValues;
}

return formData;

+ 4
- 5
web/src/pages/agent/hooks.tsx 查看文件

@@ -15,10 +15,10 @@ import React, {
// import { shallow } from 'zustand/shallow';
import { settledModelVariableMap } from '@/constants/knowledge';
import { useFetchModelId } from '@/hooks/logic-hooks';
import { ISwitchForm } from '@/interfaces/database/agent';
import {
ICategorizeForm,
IRelevantForm,
ISwitchForm,
RAGFlowNodeType,
} from '@/interfaces/database/flow';
import { message } from 'antd';
@@ -543,9 +543,9 @@ export const useWatchNodeFormDataChange = () => {
case Operator.Categorize:
buildCategorizeEdgesByFormData(node.id, form as ICategorizeForm);
break;
case Operator.Switch:
buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
break;
// case Operator.Switch:
// buildSwitchEdgesByFormData(node.id, form as ISwitchForm);
// break;
default:
break;
}
@@ -555,7 +555,6 @@ export const useWatchNodeFormDataChange = () => {
buildCategorizeEdgesByFormData,
getNode,
buildRelevantEdgesByFormData,
buildSwitchEdgesByFormData,
]);
};


+ 1
- 0
web/src/pages/agent/hooks/use-add-node.ts 查看文件

@@ -224,6 +224,7 @@ export function useAddNode(reactFlowInstance?: ReactFlowInstance<any, any>) {
[
addEdge,
addNode,
edges,
getNode,
getNodeName,
initializeOperatorParams,

+ 18
- 18
web/src/pages/agent/hooks/use-get-begin-query.tsx 查看文件

@@ -135,24 +135,6 @@ export const useBuildVariableOptions = (nodeId?: string) => {
return options;
};

export const useGetComponentLabelByValue = (nodeId: string) => {
const options = useBuildVariableOptions(nodeId);

const flattenOptions = useMemo(() => {
return options.reduce<DefaultOptionType[]>((pre, cur) => {
return [...pre, ...cur.options];
}, []);
}, [options]);

const getLabel = useCallback(
(val?: string) => {
return flattenOptions.find((x) => x.value === val)?.label;
},
[flattenOptions],
);
return getLabel;
};

export function useBuildQueryVariableOptions() {
const { data } = useFetchAgent();
const node = useContext(AgentFormContext);
@@ -220,3 +202,21 @@ export function useBuildComponentIdAndBeginOptions(

return [...beginOptions, ...componentIdOptions];
}

export const useGetComponentLabelByValue = (nodeId: string) => {
const options = useBuildComponentIdAndBeginOptions(nodeId);

const flattenOptions = useMemo(() => {
return options.reduce<DefaultOptionType[]>((pre, cur) => {
return [...pre, ...cur.options];
}, []);
}, [options]);

const getLabel = useCallback(
(val?: string) => {
return flattenOptions.find((x) => x.value === val)?.label;
},
[flattenOptions],
);
return getLabel;
};

+ 27
- 9
web/src/pages/agent/store.ts 查看文件

@@ -56,6 +56,7 @@ export type RFState = {
source: string,
sourceHandle?: string | null,
target?: string | null,
isConnecting?: boolean,
) => void;
deletePreviousEdgeOfClassificationNode: (connection: Connection) => void;
duplicateNode: (id: string, name: string) => void;
@@ -204,7 +205,7 @@ const useGraphStore = create<RFState>()(
]);
break;
case Operator.Switch: {
updateSwitchFormData(source, sourceHandle, target);
updateSwitchFormData(source, sourceHandle, target, true);
break;
}
default:
@@ -219,7 +220,7 @@ const useGraphStore = create<RFState>()(
const anchoredNodes = [
Operator.Categorize,
Operator.Relevant,
Operator.Switch,
// Operator.Switch,
];
if (
anchoredNodes.some(
@@ -303,7 +304,7 @@ const useGraphStore = create<RFState>()(
const currentEdge = edges.find((x) => x.id === id);

if (currentEdge) {
const { source, sourceHandle } = currentEdge;
const { source, sourceHandle, target } = currentEdge;
const operatorType = getOperatorTypeFromId(source);
// After deleting the edge, set the corresponding field in the node's form field to undefined
switch (operatorType) {
@@ -321,7 +322,7 @@ const useGraphStore = create<RFState>()(
]);
break;
case Operator.Switch: {
updateSwitchFormData(source, sourceHandle, undefined);
updateSwitchFormData(source, sourceHandle, target, false);
break;
}
default:
@@ -402,15 +403,32 @@ const useGraphStore = create<RFState>()(

return nextNodes;
},
updateSwitchFormData: (source, sourceHandle, target) => {
const { updateNodeForm } = get();
updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
const { updateNodeForm, edges } = get();
if (sourceHandle) {
// A handle will connect to multiple downstream nodes
let currentHandleTargets = edges
.filter(
(x) =>
x.source === source &&
x.sourceHandle === sourceHandle &&
typeof x.target === 'string',
)
.map((x) => x.target);

let targets: string[] = currentHandleTargets;
if (target) {
if (!isConnecting) {
targets = currentHandleTargets.filter((x) => x !== target);
}
}

if (sourceHandle === SwitchElseTo) {
updateNodeForm(source, target, [SwitchElseTo]);
updateNodeForm(source, targets, [SwitchElseTo]);
} else {
const operatorIndex = getOperatorIndex(sourceHandle);
if (operatorIndex) {
updateNodeForm(source, target, [
updateNodeForm(source, targets, [
'conditions',
Number(operatorIndex) - 1, // The index is the conditions form index
'to',
@@ -448,7 +466,7 @@ const useGraphStore = create<RFState>()(
return generateNodeNamesWithIncreasingIndex(name, nodes);
},
})),
{ name: 'graph' },
{ name: 'graph', trace: true },
),
);


Loading…
取消
儲存