
Merge branch 'feat/rag-2' of https://github.com/langgenius/dify into feat/rag-2

tags/2.0.0-beta.1
twwu · 2 months ago · commit 9a79d8941e

api/celery_entrypoint.py (+22 -0)

@@ -0,0 +1,22 @@
+import logging
+
+import psycogreen.gevent as psycogreen_gevent  # type: ignore
+from grpc.experimental import gevent as grpc_gevent  # type: ignore
+
+_logger = logging.getLogger(__name__)
+
+
+def _log(message: str):
+    print(message, flush=True)
+
+
+# Patch gRPC and psycopg2 for gevent before the app (and its connections) is imported.
+grpc_gevent.init_gevent()
+_log("gRPC patched with gevent.")
+psycogreen_gevent.patch_psycopg()
+_log("psycopg2 patched with gevent.")
+
+
+from app import app, celery
+
+__all__ = ["app", "celery"]

api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py (+4 -32)

@@ -10,6 +10,10 @@ from controllers.console import api
from controllers.console.app.error import (
    DraftWorkflowNotExist,
)
+from controllers.console.app.workflow_draft_variable import (
+    _WORKFLOW_DRAFT_VARIABLE_FIELDS,
+    _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
+)
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
@@ -70,38 +74,6 @@ def _create_pagination_parser():
    return parser


-_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
-    "id": fields.String,
-    "type": fields.String(attribute=lambda model: model.get_variable_type()),
-    "name": fields.String,
-    "description": fields.String,
-    "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
-    "value_type": fields.String,
-    "edited": fields.Boolean(attribute=lambda model: model.edited),
-    "visible": fields.Boolean,
-}
-
-_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
-    _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
-    value=fields.Raw(attribute=_serialize_var_value),
-)
-
-_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
-    "id": fields.String,
-    "type": fields.String(attribute=lambda _: "env"),
-    "name": fields.String,
-    "description": fields.String,
-    "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
-    "value_type": fields.String,
-    "edited": fields.Boolean(attribute=lambda model: model.edited),
-    "visible": fields.Boolean,
-}
-
-_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
-    "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
-}
-
-
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
    return var_list.variables


api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py (+3 -1)

@@ -984,6 +984,7 @@ class RagPipelineDatasourceVariableApi(Resource):
        )
        return workflow_node_execution

+
class RagPipelineRecommendedPluginApi(Resource):
    @setup_required
    @login_required
@@ -993,6 +994,7 @@ class RagPipelineRecommendedPluginApi(Resource):
        recommended_plugins = rag_pipeline_service.get_recommended_plugins()
        return recommended_plugins

+
api.add_resource(
    DraftRagPipelineApi,
    "/rag/pipelines/<uuid:pipeline_id>/workflows/draft",
@@ -1105,4 +1107,4 @@ api.add_resource(
api.add_resource(
    RagPipelineRecommendedPluginApi,
    "/rag/pipelines/recommended-plugins",
-)
\ No newline at end of file
+)

api/core/app/app_config/workflow_ui_based_app/variables/manager.py (+4 -0)

@@ -56,6 +56,10 @@ class WorkflowVariablesConfigManager:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    variables_map.pop(last_part)
+           if value.get("value") and isinstance(value.get("value"), list):
+               last_part = value.get("value")[-1]
+               variables_map.pop(last_part)

        all_second_step_variables = list(variables_map.values())

        for item in all_second_step_variables:
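
A minimal sketch (values are hypothetical) of the two parameter shapes this branch now covers: a string carrying a {{#...#}} reference, already handled by the regex match above, and a list selector whose last segment names the variable:

    value = {"value": "{{#1719456738281.file_name#}}"}  # string form, resolved via the regex
    value = {"value": ["1719456738281", "file_name"]}   # list selector form, the new branch
    last_part = value["value"][-1]                      # -> "file_name"
    # variables_map.pop(last_part) then removes it from the second-step variables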

api/docker/entrypoint.sh (+1 -1)

@@ -30,7 +30,7 @@ if [[ "${MODE}" == "worker" ]]; then
  CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
fi

-exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
+exec celery -A celery_entrypoint.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
  --max-tasks-per-child ${MAX_TASKS_PER_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
  -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation}


api/fields/workflow_run_fields.py (+3 -3)

@@ -116,9 +116,9 @@ workflow_run_node_execution_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField,
# "inputs_truncated": fields.Boolean,
# "outputs_truncated": fields.Boolean,
# "process_data_truncated": fields.Boolean,
"inputs_truncated": fields.Boolean,
"outputs_truncated": fields.Boolean,
"process_data_truncated": fields.Boolean,
}

workflow_run_node_execution_list_fields = {
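
A hedged sketch of the effect (the class below is a stand-in, not the real model): flask-restful now marshals the three flags, so API consumers can tell whether inputs, outputs, or process_data were truncated in the stored record:

    from flask_restful import fields, marshal

    demo_fields = {
        "inputs_truncated": fields.Boolean,
        "outputs_truncated": fields.Boolean,
        "process_data_truncated": fields.Boolean,
    }

    class _FakeExec:  # stand-in for WorkflowNodeExecutionModel
        inputs_truncated = True
        outputs_truncated = False
        process_data_truncated = False

    print(dict(marshal(_FakeExec(), demo_fields)))
    # {'inputs_truncated': True, 'outputs_truncated': False, 'process_data_truncated': False}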

api/gunicorn.conf.py (+10 -0)

@@ -0,0 +1,10 @@
+import psycogreen.gevent as psycogreen_gevent  # type: ignore
+from grpc.experimental import gevent as grpc_gevent  # type: ignore
+
+
+def post_fork(server, worker):
+    # Patch gRPC and psycopg2 for gevent in each worker right after fork.
+    grpc_gevent.init_gevent()
+    server.log.info("gRPC patched with gevent.")
+    psycogreen_gevent.patch_psycopg()
+    server.log.info("psycopg2 patched with gevent.")

api/services/rag_pipeline/rag_pipeline.py (+31 -19)

@@ -512,20 +512,28 @@ class RagPipelineService:

        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
        for key, value in datasource_parameters.items():
-           if value.get("value") and isinstance(value.get("value"), str):
+           param_value = value.get("value")
+
+           if not param_value:
+               variables_map[key] = param_value
+           elif isinstance(param_value, str):
+               # handle string type parameter value: check if it contains a variable reference pattern
                pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
-               match = re.match(pattern, value["value"])
+               match = re.match(pattern, param_value)
                if match:
+                   # extract the variable path and try to get its value from user inputs
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
-                   if last_part in user_inputs:
-                       variables_map[key] = user_inputs[last_part]
-                   else:
-                       variables_map[key] = value["value"]
+                   variables_map[key] = user_inputs.get(last_part, param_value)
                else:
-                   variables_map[key] = value["value"]
+                   variables_map[key] = param_value
+           elif isinstance(param_value, list) and param_value:
+               # handle list type parameter value: check whether the last element is in user inputs
+               last_part = param_value[-1]
+               variables_map[key] = user_inputs.get(last_part, param_value)
            else:
-               variables_map[key] = value["value"]
+               # any other type: use the original value directly
+               variables_map[key] = param_value

        from core.datasource.datasource_manager import DatasourceManager
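
A worked example of the reference pattern above (node id and variable name are made up): re.match only fires for values that start with the {{#...#}} wrapper, and the lookup key is the last dot-separated segment of the captured path.

    import re

    pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
    user_inputs = {"file_name": "report.pdf"}  # hypothetical first-step inputs

    m = re.match(pattern, "{{#1719456738281.file_name#}}")
    full_path = m.group(1)                # "1719456738281.file_name"
    last_part = full_path.split(".")[-1]  # "file_name"
    print(user_inputs.get(last_part, m.group(0)))  # "report.pdf"; falls back to the raw value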

@@ -931,6 +939,10 @@ class RagPipelineService:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    user_input_variables.append(variables_map.get(last_part, {}))
+           elif value.get("value") and isinstance(value.get("value"), list):
+               last_part = value.get("value")[-1]
+               user_input_variables.append(variables_map.get(last_part, {}))

        return user_input_variables

@@ -968,6 +980,9 @@ def get_second_step_parameters(self, pipeline: Pipeline, node_id: str, is_draft: bool = False) -> list[dict]:
                    full_path = match.group(1)
                    last_part = full_path.split(".")[-1]
                    variables_map.pop(last_part)
+           elif value.get("value") and isinstance(value.get("value"), list):
+               last_part = value.get("value")[-1]
+               variables_map.pop(last_part)
        all_second_step_variables = list(variables_map.values())
datasource_provider_variables = [
item
@@ -1147,18 +1162,15 @@ class RagPipelineService:
    def get_node_last_run(
        self, pipeline: Pipeline, workflow: Workflow, node_id: str
    ) -> WorkflowNodeExecutionModel | None:
-       # TODO(QuantumGhost): This query is not fully covered by index.
-       criteria = (
-           WorkflowNodeExecutionModel.tenant_id == pipeline.tenant_id,
-           WorkflowNodeExecutionModel.app_id == pipeline.id,
-           WorkflowNodeExecutionModel.workflow_id == workflow.id,
-           WorkflowNodeExecutionModel.node_id == node_id,
-       )
-       node_exec = (
-           db.session.query(WorkflowNodeExecutionModel)
-           .filter(*criteria)
-           .order_by(WorkflowNodeExecutionModel.created_at.desc())
-           .first()
-       )
+       node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+           sessionmaker(db.engine)
+       )
+
+       node_exec = node_execution_service_repo.get_node_last_execution(
+           tenant_id=pipeline.tenant_id,
+           app_id=pipeline.id,
+           workflow_id=workflow.id,
+           node_id=node_id,
+       )
        return node_exec


web/app/components/workflow/nodes/data-source/default.ts (+3 -3)

@@ -8,7 +8,7 @@ import {
  LOCAL_FILE_OUTPUT,
} from './constants'
import { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types'
-import type { AnyObj } from '../_base/components/variable/match-schema-type'
+import { getMatchedSchemaType } from '../_base/components/variable/use-match-schema-type'

const i18nPrefix = 'workflow.errorMsg'

@@ -54,7 +54,7 @@ const nodeDefault: NodeDefault<DataSourceNodeType> = {
      errorMessage,
    }
  },
-  getOutputVars(payload, allPluginInfoList, ragVars = [], { getMatchedSchemaType } = { getMatchedSchemaType: (_obj: AnyObj) => '' }) {
+  getOutputVars(payload, allPluginInfoList, ragVars = [], { schemaTypeDefinitions } = { schemaTypeDefinitions: [] }) {
    const {
      plugin_id,
      datasource_name,
@@ -74,7 +74,7 @@ const nodeDefault: NodeDefault<DataSourceNodeType> = {
        let type = dataType === 'array'
          ? `array[${output.items?.type.slice(0, 1).toLocaleLowerCase()}${output.items?.type.slice(1)}]`
          : `${dataType.slice(0, 1).toLocaleLowerCase()}${dataType.slice(1)}`
-       const schemaType = getMatchedSchemaType?.(output)
+       const schemaType = getMatchedSchemaType?.(output, schemaTypeDefinitions)

        if (type === 'object' && schemaType === 'file')
          type = 'file'
