Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>tags/0.3.22
| @@ -29,6 +29,7 @@ model_config_fields = { | |||
| 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), | |||
| 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), | |||
| 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), | |||
| 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), | |||
| 'more_like_this': fields.Raw(attribute='more_like_this_dict'), | |||
| 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| @@ -42,6 +42,7 @@ class CompletionMessageApi(Resource): | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('model_config', type=dict, required=True, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] != 'blocking' | |||
| @@ -115,6 +116,7 @@ class ChatMessageApi(Resource): | |||
| parser.add_argument('model_config', type=dict, required=True, location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] != 'blocking' | |||
| @@ -33,6 +33,7 @@ class CompletionApi(InstalledAppResource): | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| @@ -92,6 +93,7 @@ class ChatApi(InstalledAppResource): | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| @@ -30,6 +30,25 @@ class MessageListApi(InstalledAppResource): | |||
| 'rating': fields.String | |||
| } | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| @@ -37,6 +56,7 @@ class MessageListApi(InstalledAppResource): | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField | |||
| } | |||
| @@ -24,6 +24,7 @@ class AppParameterApi(InstalledAppResource): | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| } | |||
| @@ -39,6 +40,7 @@ class AppParameterApi(InstalledAppResource): | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list | |||
| } | |||
| @@ -29,9 +29,11 @@ class UniversalChatApi(UniversalChatResource): | |||
| parser.add_argument('provider', type=str, required=True, location='json') | |||
| parser.add_argument('model', type=str, required=True, location='json') | |||
| parser.add_argument('tools', type=list, required=True, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='universal_app', location='json') | |||
| args = parser.parse_args() | |||
| app_model_config = app_model.app_model_config | |||
| app_model_config | |||
| # update app model config | |||
| args['model_config'] = app_model_config.to_dict() | |||
| @@ -36,6 +36,25 @@ class UniversalChatMessageListApi(UniversalChatResource): | |||
| 'created_at': TimestampField | |||
| } | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| @@ -43,6 +62,7 @@ class UniversalChatMessageListApi(UniversalChatResource): | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField, | |||
| 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)) | |||
| } | |||
| @@ -1,4 +1,6 @@ | |||
| # -*- coding:utf-8 -*- | |||
| import json | |||
| from flask_restful import marshal_with, fields | |||
| from controllers.console import api | |||
| @@ -14,6 +16,7 @@ class UniversalChatParameterApi(UniversalChatResource): | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| } | |||
| @marshal_with(parameters_fields) | |||
| @@ -21,12 +24,14 @@ class UniversalChatParameterApi(UniversalChatResource): | |||
| """Retrieve app parameters.""" | |||
| app_model = universal_app | |||
| app_model_config = app_model.app_model_config | |||
| app_model_config.retriever_resource = json.dumps({'enabled': True}) | |||
| return { | |||
| 'opening_statement': app_model_config.opening_statement, | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| } | |||
| @@ -47,6 +47,7 @@ def universal_chat_app_required(view=None): | |||
| suggested_questions=json.dumps([]), | |||
| suggested_questions_after_answer=json.dumps({'enabled': True}), | |||
| speech_to_text=json.dumps({'enabled': True}), | |||
| retriever_resource=json.dumps({'enabled': True}), | |||
| more_like_this=None, | |||
| sensitive_word_avoidance=None, | |||
| model=json.dumps({ | |||
| @@ -25,6 +25,7 @@ class AppParameterApi(AppApiResource): | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| } | |||
| @@ -39,6 +40,7 @@ class AppParameterApi(AppApiResource): | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list | |||
| } | |||
| @@ -30,6 +30,8 @@ class CompletionApi(AppApiResource): | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('user', type=str, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| @@ -91,6 +93,8 @@ class ChatApi(AppApiResource): | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('user', type=str, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| @@ -16,6 +16,24 @@ class MessageListApi(AppApiResource): | |||
| feedback_fields = { | |||
| 'rating': fields.String | |||
| } | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| @@ -24,6 +42,7 @@ class MessageListApi(AppApiResource): | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField | |||
| } | |||
| @@ -24,6 +24,7 @@ class AppParameterApi(WebApiResource): | |||
| 'suggested_questions': fields.Raw, | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| } | |||
| @@ -38,6 +39,7 @@ class AppParameterApi(WebApiResource): | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list | |||
| } | |||
| @@ -31,6 +31,8 @@ class CompletionApi(WebApiResource): | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json', default='') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| @@ -88,6 +90,8 @@ class ChatApi(WebApiResource): | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] == 'streaming' | |||
| @@ -29,6 +29,25 @@ class MessageListApi(WebApiResource): | |||
| 'rating': fields.String | |||
| } | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| @@ -36,6 +55,7 @@ class MessageListApi(WebApiResource): | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField | |||
| } | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| from typing import Tuple, List, Any, Union, Sequence, Optional, cast | |||
| from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent | |||
| @@ -53,6 +54,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| tool = next(iter(self.tools)) | |||
| tool = cast(DatasetRetrieverTool, tool) | |||
| rst = tool.run(tool_input={'query': kwargs['input']}) | |||
| # output = '' | |||
| # rst_json = json.loads(rst) | |||
| # for item in rst_json: | |||
| # output += f'{item["content"]}\n' | |||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||
| if intermediate_steps: | |||
| @@ -64,12 +64,9 @@ class DatasetToolCallbackHandler(BaseCallbackHandler): | |||
| llm_prefix: Optional[str] = None, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| # kwargs={'name': 'Search'} | |||
| # llm_prefix='Thought:' | |||
| # observation_prefix='Observation: ' | |||
| # output='53 years' | |||
| pass | |||
| def on_tool_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| @@ -2,6 +2,7 @@ from typing import List | |||
| from langchain.schema import Document | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment | |||
| @@ -9,8 +10,9 @@ from models.dataset import DocumentSegment | |||
| class DatasetIndexToolCallbackHandler: | |||
| """Callback handler for dataset tool.""" | |||
| def __init__(self, dataset_id: str) -> None: | |||
| def __init__(self, dataset_id: str, conversation_message_task: ConversationMessageTask) -> None: | |||
| self.dataset_id = dataset_id | |||
| self.conversation_message_task = conversation_message_task | |||
| def on_tool_end(self, documents: List[Document]) -> None: | |||
| """Handle tool end.""" | |||
| @@ -27,3 +29,7 @@ class DatasetIndexToolCallbackHandler: | |||
| ) | |||
| db.session.commit() | |||
| def return_retriever_resource_info(self, resource: List): | |||
| """Handle return_retriever_resource_info.""" | |||
| self.conversation_message_task.on_dataset_query_finish(resource) | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| import logging | |||
| import re | |||
| from typing import Optional, List, Union, Tuple | |||
| @@ -19,13 +20,15 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT | |||
| from models.dataset import DocumentSegment, Dataset, Document | |||
| from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser | |||
| class Completion: | |||
| @classmethod | |||
| def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, | |||
| user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, is_override: bool = False): | |||
| user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool, | |||
| is_override: bool = False, retriever_from: str = 'dev'): | |||
| """ | |||
| errors: ProviderTokenNotInitError | |||
| """ | |||
| @@ -96,7 +99,6 @@ class Completion: | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if should_use_agent: | |||
| agent_execute_result = agent_executor.run(query) | |||
| # run the final llm | |||
| try: | |||
| cls.run_final_llm( | |||
| @@ -118,7 +120,8 @@ class Completion: | |||
| return | |||
| @classmethod | |||
| def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, | |||
| def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, | |||
| inputs: dict, | |||
| agent_execute_result: Optional[AgentExecuteResult], | |||
| conversation_message_task: ConversationMessageTask, | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): | |||
| @@ -150,7 +153,6 @@ class Completion: | |||
| callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], | |||
| fake_response=fake_response | |||
| ) | |||
| return response | |||
| @classmethod | |||
| @@ -1,6 +1,6 @@ | |||
| import decimal | |||
| import json | |||
| from typing import Optional, Union | |||
| from typing import Optional, Union, List | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | |||
| @@ -15,7 +15,8 @@ from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DatasetQuery | |||
| from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain | |||
| from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \ | |||
| MessageChain, DatasetRetrieverResource | |||
| class ConversationMessageTask: | |||
| @@ -41,6 +42,8 @@ class ConversationMessageTask: | |||
| self.message = None | |||
| self.retriever_resource = None | |||
| self.model_dict = self.app_model_config.model_dict | |||
| self.provider_name = self.model_dict.get('provider') | |||
| self.model_name = self.model_dict.get('name') | |||
| @@ -157,7 +160,8 @@ class ConversationMessageTask: | |||
| self.message.message_tokens = message_tokens | |||
| self.message.message_unit_price = message_unit_price | |||
| self.message.message_price_unit = message_price_unit | |||
| self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else '' | |||
| self.message.answer = PromptBuilder.process_template( | |||
| llm_message.completion.strip()) if llm_message.completion else '' | |||
| self.message.answer_tokens = answer_tokens | |||
| self.message.answer_unit_price = answer_unit_price | |||
| self.message.answer_price_unit = answer_price_unit | |||
| @@ -256,7 +260,36 @@ class ConversationMessageTask: | |||
| db.session.add(dataset_query) | |||
| def on_dataset_query_finish(self, resource: List): | |||
| if resource and len(resource) > 0: | |||
| for item in resource: | |||
| dataset_retriever_resource = DatasetRetrieverResource( | |||
| message_id=self.message.id, | |||
| position=item.get('position'), | |||
| dataset_id=item.get('dataset_id'), | |||
| dataset_name=item.get('dataset_name'), | |||
| document_id=item.get('document_id'), | |||
| document_name=item.get('document_name'), | |||
| data_source_type=item.get('data_source_type'), | |||
| segment_id=item.get('segment_id'), | |||
| score=item.get('score') if 'score' in item else None, | |||
| hit_count=item.get('hit_count') if 'hit_count' else None, | |||
| word_count=item.get('word_count') if 'word_count' in item else None, | |||
| segment_position=item.get('segment_position') if 'segment_position' in item else None, | |||
| index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, | |||
| content=item.get('content'), | |||
| retriever_from=item.get('retriever_from'), | |||
| created_by=self.user.id | |||
| ) | |||
| db.session.add(dataset_retriever_resource) | |||
| db.session.flush() | |||
| self.retriever_resource = resource | |||
| def message_end(self): | |||
| self._pub_handler.pub_message_end(self.retriever_resource) | |||
| def end(self): | |||
| self._pub_handler.pub_message_end(self.retriever_resource) | |||
| self._pub_handler.pub_end() | |||
| @@ -350,6 +383,23 @@ class PubHandler: | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_message_end(self, retriever_resource: List): | |||
| content = { | |||
| 'event': 'message_end', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id | |||
| } | |||
| } | |||
| if retriever_resource: | |||
| content['data']['retriever_resources'] = retriever_resource | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_end(self): | |||
| content = { | |||
| @@ -74,7 +74,7 @@ class KeywordTableIndex(BaseIndex): | |||
| DocumentSegment.document_id == document_id | |||
| ).all() | |||
| ids = [segment.id for segment in segments] | |||
| ids = [segment.index_node_id for segment in segments] | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keyword_table = self._delete_ids_from_keyword_table(keyword_table, ids) | |||
| @@ -113,6 +113,25 @@ class QdrantVectorIndex(BaseVectorIndex): | |||
| ], | |||
| )) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| return | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| for node_id in ids: | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key="metadata.doc_id", | |||
| match=models.MatchValue(value=node_id), | |||
| ), | |||
| ], | |||
| )) | |||
| def _is_origin(self): | |||
| if self.dataset.index_struct_dict: | |||
| class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix'] | |||
| @@ -8,6 +8,7 @@ class LLMRunResult(BaseModel): | |||
| content: str | |||
| prompt_tokens: int | |||
| completion_tokens: int | |||
| source: list = None | |||
| class MessageType(enum.Enum): | |||
| @@ -36,8 +36,8 @@ class OrchestratorRuleParser: | |||
| self.app_model_config = app_model_config | |||
| def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], | |||
| rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ | |||
| -> Optional[AgentExecutor]: | |||
| rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, | |||
| return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]: | |||
| if not self.app_model_config.agent_mode_dict: | |||
| return None | |||
| @@ -74,7 +74,7 @@ class OrchestratorRuleParser: | |||
| # only OpenAI chat model (include Azure) support function call, use ReACT instead | |||
| if agent_model_instance.model_mode != ModelMode.CHAT \ | |||
| or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']: | |||
| or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']: | |||
| if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: | |||
| planning_strategy = PlanningStrategy.REACT | |||
| elif planning_strategy == PlanningStrategy.ROUTER: | |||
| @@ -99,7 +99,9 @@ class OrchestratorRuleParser: | |||
| tool_configs=tool_configs, | |||
| conversation_message_task=conversation_message_task, | |||
| rest_tokens=rest_tokens, | |||
| callbacks=[agent_callback, DifyStdOutCallbackHandler()] | |||
| callbacks=[agent_callback, DifyStdOutCallbackHandler()], | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from | |||
| ) | |||
| if len(tools) == 0: | |||
| @@ -145,8 +147,10 @@ class OrchestratorRuleParser: | |||
| return None | |||
| def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, conversation_message_task: ConversationMessageTask, | |||
| rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]: | |||
| def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, | |||
| conversation_message_task: ConversationMessageTask, | |||
| rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False, | |||
| retriever_from: str = 'dev') -> list[BaseTool]: | |||
| """ | |||
| Convert app agent tool configs to tools | |||
| @@ -155,6 +159,8 @@ class OrchestratorRuleParser: | |||
| :param tool_configs: app agent tool configs | |||
| :param conversation_message_task: | |||
| :param callbacks: | |||
| :param return_resource: | |||
| :param retriever_from: | |||
| :return: | |||
| """ | |||
| tools = [] | |||
| @@ -166,7 +172,7 @@ class OrchestratorRuleParser: | |||
| tool = None | |||
| if tool_type == "dataset": | |||
| tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens) | |||
| tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from) | |||
| elif tool_type == "web_reader": | |||
| tool = self.to_web_reader_tool(agent_model_instance) | |||
| elif tool_type == "google_search": | |||
| @@ -183,13 +189,15 @@ class OrchestratorRuleParser: | |||
| return tools | |||
| def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, | |||
| rest_tokens: int) \ | |||
| rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \ | |||
| -> Optional[BaseTool]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param rest_tokens: | |||
| :param tool_config: | |||
| :param conversation_message_task: | |||
| :param return_resource: | |||
| :param retriever_from: | |||
| :return: | |||
| """ | |||
| # get dataset from dataset id | |||
| @@ -208,7 +216,10 @@ class OrchestratorRuleParser: | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| k=k, | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)] | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||
| conversation_message_task=conversation_message_task, | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from | |||
| ) | |||
| return tool | |||
| @@ -10,4 +10,4 @@ | |||
| ], | |||
| "query_prompt": "\n\nHuman: {{query}}\n\nAssistant: ", | |||
| "stops": ["\nHuman:", "</histories>"] | |||
| } | |||
| } | |||
| @@ -105,7 +105,7 @@ GENERATOR_QA_PROMPT = ( | |||
| 'Step 3: Decompose or combine multiple pieces of information and concepts.\n' | |||
| 'Step 4: Generate 20 questions and answers based on these key information and concepts.' | |||
| 'The questions should be clear and detailed, and the answers should be detailed and complete.\n' | |||
| "Answer must be the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n" | |||
| "Answer according to the the language:{language} and in the following format: Q1:\nA1:\nQ2:\nA2:...\n" | |||
| ) | |||
| RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ | |||
| @@ -1,3 +1,4 @@ | |||
| import json | |||
| from typing import Type | |||
| from flask import current_app | |||
| @@ -5,13 +6,14 @@ from langchain.tools import BaseTool | |||
| from pydantic import Field, BaseModel | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.dataset import Dataset, DocumentSegment, Document | |||
| class DatasetRetrieverToolInput(BaseModel): | |||
| @@ -27,6 +29,10 @@ class DatasetRetrieverTool(BaseTool): | |||
| tenant_id: str | |||
| dataset_id: str | |||
| k: int = 3 | |||
| conversation_message_task: ConversationMessageTask | |||
| return_resource: str | |||
| retriever_from: str | |||
| @classmethod | |||
| def from_dataset(cls, dataset: Dataset, **kwargs): | |||
| @@ -86,7 +92,7 @@ class DatasetRetrieverTool(BaseTool): | |||
| if self.k > 0: | |||
| documents = vector_index.search( | |||
| query, | |||
| search_type='similarity', | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': self.k | |||
| } | |||
| @@ -94,8 +100,12 @@ class DatasetRetrieverTool(BaseTool): | |||
| else: | |||
| documents = [] | |||
| hit_callback = DatasetIndexToolCallbackHandler(dataset.id) | |||
| hit_callback = DatasetIndexToolCallbackHandler(dataset.id, self.conversation_message_task) | |||
| hit_callback.on_tool_end(documents) | |||
| document_score_list = {} | |||
| if dataset.indexing_technique != "economy": | |||
| for item in documents: | |||
| document_score_list[item.metadata['doc_id']] = item.metadata['score'] | |||
| document_context_list = [] | |||
| index_node_ids = [document.metadata['doc_id'] for document in documents] | |||
| segments = DocumentSegment.query.filter(DocumentSegment.dataset_id == self.dataset_id, | |||
| @@ -112,9 +122,43 @@ class DatasetRetrieverTool(BaseTool): | |||
| float('inf'))) | |||
| for segment in sorted_segments: | |||
| if segment.answer: | |||
| document_context_list.append(f'question:{segment.content} \nanswer:{segment.answer}') | |||
| document_context_list.append(f'question:{segment.content} answer:{segment.answer}') | |||
| else: | |||
| document_context_list.append(segment.content) | |||
| if self.return_resource: | |||
| context_list = [] | |||
| resource_number = 1 | |||
| for segment in sorted_segments: | |||
| context = {} | |||
| document = Document.query.filter(Document.id == segment.document_id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ).first() | |||
| if dataset and document: | |||
| source = { | |||
| 'position': resource_number, | |||
| 'dataset_id': dataset.id, | |||
| 'dataset_name': dataset.name, | |||
| 'document_id': document.id, | |||
| 'document_name': document.name, | |||
| 'data_source_type': document.data_source_type, | |||
| 'segment_id': segment.id, | |||
| 'retriever_from': self.retriever_from | |||
| } | |||
| if dataset.indexing_technique != "economy": | |||
| source['score'] = document_score_list.get(segment.index_node_id) | |||
| if self.retriever_from == 'dev': | |||
| source['hit_count'] = segment.hit_count | |||
| source['word_count'] = segment.word_count | |||
| source['segment_position'] = segment.position | |||
| source['index_node_hash'] = segment.index_node_hash | |||
| if segment.answer: | |||
| source['content'] = f'question:{segment.content} \nanswer:{segment.answer}' | |||
| else: | |||
| source['content'] = segment.content | |||
| context_list.append(source) | |||
| resource_number += 1 | |||
| hit_callback.return_retriever_resource_info(context_list) | |||
| return str("\n".join(document_context_list)) | |||
| @@ -0,0 +1,54 @@ | |||
| """add_dataset_retriever_resource | |||
| Revision ID: 6dcb43972bdc | |||
| Revises: 4bcffcd64aa4 | |||
| Create Date: 2023-09-06 16:51:27.385844 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '6dcb43972bdc' | |||
| down_revision = '4bcffcd64aa4' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('dataset_retriever_resources', | |||
| sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('message_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('position', sa.Integer(), nullable=False), | |||
| sa.Column('dataset_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('dataset_name', sa.Text(), nullable=False), | |||
| sa.Column('document_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('document_name', sa.Text(), nullable=False), | |||
| sa.Column('data_source_type', sa.Text(), nullable=False), | |||
| sa.Column('segment_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('score', sa.Float(), nullable=True), | |||
| sa.Column('content', sa.Text(), nullable=False), | |||
| sa.Column('hit_count', sa.Integer(), nullable=True), | |||
| sa.Column('word_count', sa.Integer(), nullable=True), | |||
| sa.Column('segment_position', sa.Integer(), nullable=True), | |||
| sa.Column('index_node_hash', sa.Text(), nullable=True), | |||
| sa.Column('retriever_from', sa.Text(), nullable=False), | |||
| sa.Column('created_by', postgresql.UUID(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey') | |||
| ) | |||
| with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: | |||
| batch_op.create_index('dataset_retriever_resource_message_id_idx', ['message_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op: | |||
| batch_op.drop_index('dataset_retriever_resource_message_id_idx') | |||
| op.drop_table('dataset_retriever_resources') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,32 @@ | |||
| """add_app_config_retriever_resource | |||
| Revision ID: 77e83833755c | |||
| Revises: 6dcb43972bdc | |||
| Create Date: 2023-09-06 17:26:40.311927 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = '77e83833755c' | |||
| down_revision = '6dcb43972bdc' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('retriever_resource', sa.Text(), nullable=True)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.drop_column('retriever_resource') | |||
| # ### end Alembic commands ### | |||
| @@ -1,4 +1,5 @@ | |||
| import json | |||
| from json import JSONDecodeError | |||
| from flask import current_app, request | |||
| from flask_login import UserMixin | |||
| @@ -90,6 +91,7 @@ class AppModelConfig(db.Model): | |||
| pre_prompt = db.Column(db.Text) | |||
| agent_mode = db.Column(db.Text) | |||
| sensitive_word_avoidance = db.Column(db.Text) | |||
| retriever_resource = db.Column(db.Text) | |||
| @property | |||
| def app(self): | |||
| @@ -114,6 +116,11 @@ class AppModelConfig(db.Model): | |||
| return json.loads(self.speech_to_text) if self.speech_to_text \ | |||
| else {"enabled": False} | |||
| @property | |||
| def retriever_resource_dict(self) -> dict: | |||
| return json.loads(self.retriever_resource) if self.retriever_resource \ | |||
| else {"enabled": False} | |||
| @property | |||
| def more_like_this_dict(self) -> dict: | |||
| return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | |||
| @@ -140,6 +147,7 @@ class AppModelConfig(db.Model): | |||
| "suggested_questions": self.suggested_questions_list, | |||
| "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, | |||
| "speech_to_text": self.speech_to_text_dict, | |||
| "retriever_resource": self.retriever_resource, | |||
| "more_like_this": self.more_like_this_dict, | |||
| "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, | |||
| "model": self.model_dict, | |||
| @@ -164,7 +172,8 @@ class AppModelConfig(db.Model): | |||
| self.user_input_form = json.dumps(model_config['user_input_form']) | |||
| self.pre_prompt = model_config['pre_prompt'] | |||
| self.agent_mode = json.dumps(model_config['agent_mode']) | |||
| self.retriever_resource = json.dumps(model_config['retriever_resource']) \ | |||
| if model_config.get('retriever_resource') else None | |||
| return self | |||
| def copy(self): | |||
| @@ -318,6 +327,7 @@ class Conversation(db.Model): | |||
| model_config['suggested_questions'] = app_model_config.suggested_questions_list | |||
| model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict | |||
| model_config['speech_to_text'] = app_model_config.speech_to_text_dict | |||
| model_config['retriever_resource'] = app_model_config.retriever_resource_dict | |||
| model_config['more_like_this'] = app_model_config.more_like_this_dict | |||
| model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict | |||
| model_config['user_input_form'] = app_model_config.user_input_form_list | |||
| @@ -476,6 +486,11 @@ class Message(db.Model): | |||
| return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \ | |||
| .order_by(MessageAgentThought.position.asc()).all() | |||
| @property | |||
| def retriever_resources(self): | |||
| return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \ | |||
| .order_by(DatasetRetrieverResource.position.asc()).all() | |||
| class MessageFeedback(db.Model): | |||
| __tablename__ = 'message_feedbacks' | |||
| @@ -719,3 +734,31 @@ class MessageAgentThought(db.Model): | |||
| created_by_role = db.Column(db.String, nullable=False) | |||
| created_by = db.Column(UUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| class DatasetRetrieverResource(db.Model): | |||
| __tablename__ = 'dataset_retriever_resources' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'), | |||
| db.Index('dataset_retriever_resource_message_id_idx', 'message_id'), | |||
| ) | |||
| id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) | |||
| message_id = db.Column(UUID, nullable=False) | |||
| position = db.Column(db.Integer, nullable=False) | |||
| dataset_id = db.Column(UUID, nullable=False) | |||
| dataset_name = db.Column(db.Text, nullable=False) | |||
| document_id = db.Column(UUID, nullable=False) | |||
| document_name = db.Column(db.Text, nullable=False) | |||
| data_source_type = db.Column(db.Text, nullable=False) | |||
| segment_id = db.Column(UUID, nullable=False) | |||
| score = db.Column(db.Float, nullable=True) | |||
| content = db.Column(db.Text, nullable=False) | |||
| hit_count = db.Column(db.Integer, nullable=True) | |||
| word_count = db.Column(db.Integer, nullable=True) | |||
| segment_position = db.Column(db.Integer, nullable=True) | |||
| index_node_hash = db.Column(db.Text, nullable=True) | |||
| retriever_from = db.Column(db.Text, nullable=False) | |||
| created_by = db.Column(UUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) | |||
| @@ -130,6 +130,21 @@ class AppModelConfigService: | |||
| if not isinstance(config["speech_to_text"]["enabled"], bool): | |||
| raise ValueError("enabled in speech_to_text must be of boolean type") | |||
| # return retriever resource | |||
| if 'retriever_resource' not in config or not config["retriever_resource"]: | |||
| config["retriever_resource"] = { | |||
| "enabled": False | |||
| } | |||
| if not isinstance(config["retriever_resource"], dict): | |||
| raise ValueError("retriever_resource must be of dict type") | |||
| if "enabled" not in config["retriever_resource"] or not config["retriever_resource"]["enabled"]: | |||
| config["retriever_resource"]["enabled"] = False | |||
| if not isinstance(config["retriever_resource"]["enabled"], bool): | |||
| raise ValueError("enabled in speech_to_text must be of boolean type") | |||
| # more_like_this | |||
| if 'more_like_this' not in config or not config["more_like_this"]: | |||
| config["more_like_this"] = { | |||
| @@ -327,6 +342,7 @@ class AppModelConfigService: | |||
| "suggested_questions": config["suggested_questions"], | |||
| "suggested_questions_after_answer": config["suggested_questions_after_answer"], | |||
| "speech_to_text": config["speech_to_text"], | |||
| "retriever_resource": config["retriever_resource"], | |||
| "more_like_this": config["more_like_this"], | |||
| "sensitive_word_avoidance": config["sensitive_word_avoidance"], | |||
| "model": { | |||
| @@ -11,7 +11,8 @@ from sqlalchemy import and_ | |||
| from core.completion import Completion | |||
| from core.conversation_message_task import PubHandler, ConversationTaskStoppedException | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, \ | |||
| LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| @@ -95,6 +96,7 @@ class CompletionService: | |||
| app_model_config_model = app_model_config.model_dict | |||
| app_model_config_model['completion_params'] = completion_params | |||
| app_model_config.retriever_resource = json.dumps({'enabled': True}) | |||
| app_model_config = app_model_config.copy() | |||
| app_model_config.model = json.dumps(app_model_config_model) | |||
| @@ -145,7 +147,8 @@ class CompletionService: | |||
| 'user': user, | |||
| 'conversation': conversation, | |||
| 'streaming': streaming, | |||
| 'is_model_config_override': is_model_config_override | |||
| 'is_model_config_override': is_model_config_override, | |||
| 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev' | |||
| }) | |||
| generate_worker_thread.start() | |||
| @@ -169,7 +172,8 @@ class CompletionService: | |||
| @classmethod | |||
| def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig, | |||
| query: str, inputs: dict, user: Union[Account, EndUser], | |||
| conversation: Conversation, streaming: bool, is_model_config_override: bool): | |||
| conversation: Conversation, streaming: bool, is_model_config_override: bool, | |||
| retriever_from: str = 'dev'): | |||
| with flask_app.app_context(): | |||
| try: | |||
| if conversation: | |||
| @@ -188,6 +192,7 @@ class CompletionService: | |||
| conversation=conversation, | |||
| streaming=streaming, | |||
| is_override=is_model_config_override, | |||
| retriever_from=retriever_from | |||
| ) | |||
| except ConversationTaskStoppedException: | |||
| pass | |||
| @@ -400,7 +405,11 @@ class CompletionService: | |||
| elif event == 'chain': | |||
| yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" | |||
| elif event == 'agent_thought': | |||
| yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" | |||
| yield "data: " + json.dumps( | |||
| cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" | |||
| elif event == 'message_end': | |||
| yield "data: " + json.dumps( | |||
| cls.get_message_end_data(result.get('data'))) + "\n\n" | |||
| elif event == 'ping': | |||
| yield "event: ping\n\n" | |||
| else: | |||
| @@ -432,6 +441,20 @@ class CompletionService: | |||
| return response_data | |||
| @classmethod | |||
| def get_message_end_data(cls, data: dict): | |||
| response_data = { | |||
| 'event': 'message_end', | |||
| 'task_id': data.get('task_id'), | |||
| 'id': data.get('message_id') | |||
| } | |||
| if 'retriever_resources' in data: | |||
| response_data['retriever_resources'] = data.get('retriever_resources') | |||
| if data.get('mode') == 'chat': | |||
| response_data['conversation_id'] = data.get('conversation_id') | |||
| return response_data | |||
| @classmethod | |||
| def get_chain_response_data(cls, data: dict): | |||
| response_data = { | |||