浏览代码

fix retrival resource miss in chatflow (#18307)

tags/1.3.0
Jyong 6 个月前
父节点
当前提交
e90c532c3a
没有帐户链接到提交者的电子邮件
共有 3 个文件被更改,包括 2 次插入30 次删除
  1. 1
    0
      api/controllers/web/message.py
  2. 0
    24
      api/core/callback_handler/index_tool_callback_handler.py
  3. 1
    6
      api/models/model.py

+ 1
- 0
api/controllers/web/message.py 查看文件

@@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}

+ 0
- 24
api/core/callback_handler/index_tool_callback_handler.py 查看文件

@@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import DatasetRetrieverResource


class DatasetIndexToolCallbackHandler:
@@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:

def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()

self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

+ 1
- 6
api/models/model.py 查看文件

@@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined]

@property
def retriever_resources(self):
return (
db.session.query(DatasetRetrieverResource)
.filter(DatasetRetrieverResource.message_id == self.id)
.order_by(DatasetRetrieverResource.position.asc())
.all()
)
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []

@property
def message_files(self):

正在加载...
取消
保存