Co-authored-by: jyong <jyong@dify.ai>tags/0.3.34
| @@ -28,7 +28,7 @@ from extensions.ext_database import db | |||
| from libs.rsa import generate_key_pair | |||
| from models.account import InvitationCode, Tenant, TenantAccountJoin | |||
| from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding | |||
| from models.model import Account, AppModelConfig, App | |||
| from models.model import Account, AppModelConfig, App, MessageAnnotation, Message | |||
| import secrets | |||
| import base64 | |||
| @@ -752,6 +752,30 @@ def migrate_default_input_to_dataset_query_variable(batch_size): | |||
| pbar.update(len(data_batch)) | |||
| @click.command('add-annotation-question-field-value', help='add annotation question value') | |||
| def add_annotation_question_field_value(): | |||
| click.echo(click.style('Start add annotation question value.', fg='green')) | |||
| message_annotations = db.session.query(MessageAnnotation).all() | |||
| message_annotation_deal_count = 0 | |||
| if message_annotations: | |||
| for message_annotation in message_annotations: | |||
| try: | |||
| if message_annotation.message_id and not message_annotation.question: | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_annotation.message_id | |||
| ).first() | |||
| message_annotation.question = message.query | |||
| db.session.add(message_annotation) | |||
| db.session.commit() | |||
| message_annotation_deal_count += 1 | |||
| except Exception as e: | |||
| click.echo( | |||
| click.style('Add annotation question value error: {} {}'.format(e.__class__.__name__, str(e)), | |||
| fg='red')) | |||
| click.echo( | |||
| click.style(f'Congratulations! add annotation question value successful. Deal count {message_annotation_deal_count}', fg='green')) | |||
| def register_commands(app): | |||
| app.cli.add_command(reset_password) | |||
| app.cli.add_command(reset_email) | |||
| @@ -766,3 +790,4 @@ def register_commands(app): | |||
| app.cli.add_command(normalization_collections) | |||
| app.cli.add_command(migrate_default_input_to_dataset_query_variable) | |||
| app.cli.add_command(add_qdrant_full_text_index) | |||
| app.cli.add_command(add_annotation_question_field_value) | |||
| @@ -9,7 +9,7 @@ api = ExternalApi(bp) | |||
| from . import extension, setup, version, apikey, admin | |||
| # Import app controllers | |||
| from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio | |||
| from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio, annotation | |||
| # Import auth controllers | |||
| from .auth import login, oauth, data_source_oauth, activate | |||
| @@ -0,0 +1,291 @@ | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, reqparse, marshal_with, marshal | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.app.error import NoFileUploadedError | |||
| from controllers.console.datasets.error import TooManyFilesError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from extensions.ext_redis import redis_client | |||
| from fields.annotation_fields import annotation_list_fields, annotation_hit_history_list_fields, annotation_fields, \ | |||
| annotation_hit_history_fields | |||
| from libs.login import login_required | |||
| from services.annotation_service import AppAnnotationService | |||
| from flask import request | |||
| class AnnotationReplyActionApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| def post(self, app_id, action): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('score_threshold', required=True, type=float, location='json') | |||
| parser.add_argument('embedding_provider_name', required=True, type=str, location='json') | |||
| parser.add_argument('embedding_model_name', required=True, type=str, location='json') | |||
| args = parser.parse_args() | |||
| if action == 'enable': | |||
| result = AppAnnotationService.enable_app_annotation(args, app_id) | |||
| elif action == 'disable': | |||
| result = AppAnnotationService.disable_app_annotation(app_id) | |||
| else: | |||
| raise ValueError('Unsupported annotation reply action') | |||
| return result, 200 | |||
| class AppAnnotationSettingDetailApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, app_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| result = AppAnnotationService.get_app_annotation_setting_by_app_id(app_id) | |||
| return result, 200 | |||
| class AppAnnotationSettingUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, app_id, annotation_setting_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| annotation_setting_id = str(annotation_setting_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('score_threshold', required=True, type=float, location='json') | |||
| args = parser.parse_args() | |||
| result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) | |||
| return result, 200 | |||
| class AnnotationReplyActionStatusApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| def get(self, app_id, job_id, action): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| job_id = str(job_id) | |||
| app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) | |||
| cache_result = redis_client.get(app_annotation_job_key) | |||
| if cache_result is None: | |||
| raise ValueError("The job is not exist.") | |||
| job_status = cache_result.decode() | |||
| error_msg = '' | |||
| if job_status == 'error': | |||
| app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) | |||
| error_msg = redis_client.get(app_annotation_error_key).decode() | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': job_status, | |||
| 'error_msg': error_msg | |||
| }, 200 | |||
| class AnnotationListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, app_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| keyword = request.args.get('keyword', default=None, type=str) | |||
| app_id = str(app_id) | |||
| annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) | |||
| response = { | |||
| 'data': marshal(annotation_list, annotation_fields), | |||
| 'has_more': len(annotation_list) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| } | |||
| return response, 200 | |||
| class AnnotationExportApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, app_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) | |||
| response = { | |||
| 'data': marshal(annotation_list, annotation_fields) | |||
| } | |||
| return response, 200 | |||
| class AnnotationCreateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('question', required=True, type=str, location='json') | |||
| parser.add_argument('answer', required=True, type=str, location='json') | |||
| args = parser.parse_args() | |||
| annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) | |||
| return annotation | |||
| class AnnotationUpdateDeleteApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_id, annotation_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| annotation_id = str(annotation_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('question', required=True, type=str, location='json') | |||
| parser.add_argument('answer', required=True, type=str, location='json') | |||
| args = parser.parse_args() | |||
| annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) | |||
| return annotation | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| def delete(self, app_id, annotation_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| annotation_id = str(annotation_id) | |||
| AppAnnotationService.delete_app_annotation(app_id, annotation_id) | |||
| return {'result': 'success'}, 200 | |||
| class AnnotationBatchImportApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| def post(self, app_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| app_id = str(app_id) | |||
| # get file from request | |||
| file = request.files['file'] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| # check file type | |||
| if not file.filename.endswith('.csv'): | |||
| raise ValueError("Invalid file type. Only CSV files are allowed") | |||
| return AppAnnotationService.batch_import_app_annotations(app_id, file) | |||
| class AnnotationBatchImportStatusApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| def get(self, app_id, job_id): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| job_id = str(job_id) | |||
| indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is None: | |||
| raise ValueError("The job is not exist.") | |||
| job_status = cache_result.decode() | |||
| error_msg = '' | |||
| if job_status == 'error': | |||
| indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) | |||
| error_msg = redis_client.get(indexing_error_msg_key).decode() | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': job_status, | |||
| 'error_msg': error_msg | |||
| }, 200 | |||
| class AnnotationHitHistoryListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, app_id, annotation_id): | |||
| # The role of the current user in the table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| app_id = str(app_id) | |||
| annotation_id = str(annotation_id) | |||
| annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, | |||
| page, limit) | |||
| response = { | |||
| 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), | |||
| 'has_more': len(annotation_hit_history_list) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| } | |||
| return response | |||
| api.add_resource(AnnotationReplyActionApi, '/apps/<uuid:app_id>/annotation-reply/<string:action>') | |||
| api.add_resource(AnnotationReplyActionStatusApi, | |||
| '/apps/<uuid:app_id>/annotation-reply/<string:action>/status/<uuid:job_id>') | |||
| api.add_resource(AnnotationListApi, '/apps/<uuid:app_id>/annotations') | |||
| api.add_resource(AnnotationExportApi, '/apps/<uuid:app_id>/annotations/export') | |||
| api.add_resource(AnnotationUpdateDeleteApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>') | |||
| api.add_resource(AnnotationBatchImportApi, '/apps/<uuid:app_id>/annotations/batch-import') | |||
| api.add_resource(AnnotationBatchImportStatusApi, '/apps/<uuid:app_id>/annotations/batch-import-status/<uuid:job_id>') | |||
| api.add_resource(AnnotationHitHistoryListApi, '/apps/<uuid:app_id>/annotations/<uuid:annotation_id>/hit-histories') | |||
| api.add_resource(AppAnnotationSettingDetailApi, '/apps/<uuid:app_id>/annotation-setting') | |||
| api.add_resource(AppAnnotationSettingUpdateApi, '/apps/<uuid:app_id>/annotation-settings/<uuid:annotation_setting_id>') | |||
| @@ -72,4 +72,16 @@ class UnsupportedAudioTypeError(BaseHTTPException): | |||
| class ProviderNotSupportSpeechToTextError(BaseHTTPException): | |||
| error_code = 'provider_not_support_speech_to_text' | |||
| description = "Provider not support speech to text." | |||
| code = 400 | |||
| code = 400 | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| @@ -6,22 +6,23 @@ from flask import Response, stream_with_context | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, reqparse, marshal_with, fields | |||
| from flask_restful.inputs import int_range | |||
| from werkzeug.exceptions import InternalServerError, NotFound | |||
| from werkzeug.exceptions import InternalServerError, NotFound, Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.app import _get_app | |||
| from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \ | |||
| AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.login import login_required | |||
| from fields.conversation_fields import message_detail_fields | |||
| from fields.conversation_fields import message_detail_fields, annotation_fields | |||
| from libs.helper import uuid_value | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| from extensions.ext_database import db | |||
| from models.model import MessageAnnotation, Conversation, Message, MessageFeedback | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.completion_service import CompletionService | |||
| from services.errors.app import MoreLikeThisDisabledError | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| @@ -151,44 +152,24 @@ class MessageAnnotationApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check('annotation') | |||
| @marshal_with(annotation_fields) | |||
| def post(self, app_id): | |||
| app_id = str(app_id) | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| # get app info | |||
| app = _get_app(app_id) | |||
| app_id = str(app_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('message_id', required=True, type=uuid_value, location='json') | |||
| parser.add_argument('content', type=str, location='json') | |||
| parser.add_argument('message_id', required=False, type=uuid_value, location='json') | |||
| parser.add_argument('question', required=True, type=str, location='json') | |||
| parser.add_argument('answer', required=True, type=str, location='json') | |||
| parser.add_argument('annotation_reply', required=False, type=dict, location='json') | |||
| args = parser.parse_args() | |||
| annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_id) | |||
| message_id = str(args['message_id']) | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app.id | |||
| ).first() | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| annotation = message.annotation | |||
| if annotation: | |||
| annotation.content = args['content'] | |||
| else: | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| conversation_id=message.conversation_id, | |||
| message_id=message.id, | |||
| content=args['content'], | |||
| account_id=current_user.id | |||
| ) | |||
| db.session.add(annotation) | |||
| db.session.commit() | |||
| return {'result': 'success'} | |||
| return annotation | |||
| class MessageAnnotationCountApi(Resource): | |||
| @@ -24,29 +24,29 @@ class ModelConfigResource(Resource): | |||
| """Modify app model config""" | |||
| app_id = str(app_id) | |||
| app_model = _get_app(app_id) | |||
| app = _get_app(app_id) | |||
| # validate config | |||
| model_configuration = AppModelConfigService.validate_configuration( | |||
| tenant_id=current_user.current_tenant_id, | |||
| account=current_user, | |||
| config=request.json, | |||
| mode=app_model.mode | |||
| mode=app.mode | |||
| ) | |||
| new_app_model_config = AppModelConfig( | |||
| app_id=app_model.id, | |||
| app_id=app.id, | |||
| ) | |||
| new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) | |||
| db.session.add(new_app_model_config) | |||
| db.session.flush() | |||
| app_model.app_model_config_id = new_app_model_config.id | |||
| app.app_model_config_id = new_app_model_config.id | |||
| db.session.commit() | |||
| app_model_config_was_updated.send( | |||
| app_model, | |||
| app, | |||
| app_model_config=new_app_model_config | |||
| ) | |||
| @@ -30,6 +30,7 @@ class AppParameterApi(InstalledAppResource): | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'sensitive_word_avoidance': fields.Raw, | |||
| @@ -49,6 +50,7 @@ class AppParameterApi(InstalledAppResource): | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'annotation_reply': app_model_config.annotation_reply_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list, | |||
| 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, | |||
| @@ -17,6 +17,7 @@ class UniversalChatParameterApi(UniversalChatResource): | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw | |||
| } | |||
| @marshal_with(parameters_fields) | |||
| @@ -32,6 +33,7 @@ class UniversalChatParameterApi(UniversalChatResource): | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'annotation_reply': app_model_config.annotation_reply_dict, | |||
| } | |||
| @@ -47,6 +47,7 @@ def universal_chat_app_required(view=None): | |||
| suggested_questions=json.dumps([]), | |||
| suggested_questions_after_answer=json.dumps({'enabled': True}), | |||
| speech_to_text=json.dumps({'enabled': True}), | |||
| annotation_reply=json.dumps({'enabled': False}), | |||
| retriever_resource=json.dumps({'enabled': True}), | |||
| more_like_this=None, | |||
| sensitive_word_avoidance=None, | |||
| @@ -55,6 +55,7 @@ def cloud_edition_billing_resource_check(resource: str, | |||
| members = billing_info['members'] | |||
| apps = billing_info['apps'] | |||
| vector_space = billing_info['vector_space'] | |||
| annotation_quota_limit = billing_info['annotation_quota_limit'] | |||
| if resource == 'members' and 0 < members['limit'] <= members['size']: | |||
| abort(403, error_msg) | |||
| @@ -62,6 +63,8 @@ def cloud_edition_billing_resource_check(resource: str, | |||
| abort(403, error_msg) | |||
| elif resource == 'vector_space' and 0 < vector_space['limit'] <= vector_space['size']: | |||
| abort(403, error_msg) | |||
| elif resource == 'annotation' and 0 < annotation_quota_limit['limit'] <= annotation_quota_limit['size']: | |||
| abort(403, error_msg) | |||
| else: | |||
| return view(*args, **kwargs) | |||
| @@ -31,6 +31,7 @@ class AppParameterApi(AppApiResource): | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'sensitive_word_avoidance': fields.Raw, | |||
| @@ -49,6 +50,7 @@ class AppParameterApi(AppApiResource): | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'annotation_reply': app_model_config.annotation_reply_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list, | |||
| 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, | |||
| @@ -30,6 +30,7 @@ class AppParameterApi(WebApiResource): | |||
| 'suggested_questions_after_answer': fields.Raw, | |||
| 'speech_to_text': fields.Raw, | |||
| 'retriever_resource': fields.Raw, | |||
| 'annotation_reply': fields.Raw, | |||
| 'more_like_this': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'sensitive_word_avoidance': fields.Raw, | |||
| @@ -48,6 +49,7 @@ class AppParameterApi(WebApiResource): | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'retriever_resource': app_model_config.retriever_resource_dict, | |||
| 'annotation_reply': app_model_config.annotation_reply_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list, | |||
| 'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict, | |||
| @@ -12,8 +12,10 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa | |||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ | |||
| ConversationTaskInterruptException | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||
| from core.file.file_obj import FileObj | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| ReadOnlyConversationTokenDBBufferSharedMemory | |||
| @@ -23,9 +25,12 @@ from core.model_providers.models.llm.base import BaseLLM | |||
| from core.orchestrator_rule_parser import OrchestratorRuleParser | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from core.prompt.prompt_transform import PromptTransform | |||
| from models.dataset import Dataset | |||
| from models.model import App, AppModelConfig, Account, Conversation, EndUser | |||
| from core.moderation.base import ModerationException, ModerationAction | |||
| from core.moderation.factory import ModerationFactory | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| class Completion: | |||
| @@ -33,7 +38,7 @@ class Completion: | |||
| def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, | |||
| files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation], | |||
| streaming: bool, is_override: bool = False, retriever_from: str = 'dev', | |||
| auto_generate_name: bool = True): | |||
| auto_generate_name: bool = True, from_source: str = 'console'): | |||
| """ | |||
| errors: ProviderTokenNotInitError | |||
| """ | |||
| @@ -109,7 +114,10 @@ class Completion: | |||
| fake_response=str(e) | |||
| ) | |||
| return | |||
| # check annotation reply | |||
| annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source) | |||
| if annotation_reply: | |||
| return | |||
| # fill in variable inputs from external data tools if exists | |||
| external_data_tools = app_model_config.external_data_tools_list | |||
| if external_data_tools: | |||
| @@ -166,17 +174,18 @@ class Completion: | |||
| except ChunkedEncodingError as e: | |||
| # Interrupt by LLM (like OpenAI), handle it. | |||
| logging.warning(f'ChunkedEncodingError: {e}') | |||
| conversation_message_task.end() | |||
| return | |||
| @classmethod | |||
| def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str): | |||
| def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, | |||
| query: str): | |||
| if not app_model_config.sensitive_word_avoidance_dict['enabled']: | |||
| return inputs, query | |||
| type = app_model_config.sensitive_word_avoidance_dict['type'] | |||
| moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config']) | |||
| moderation = ModerationFactory(type, app_id, tenant_id, | |||
| app_model_config.sensitive_word_avoidance_dict['config']) | |||
| moderation_result = moderation.moderation_for_inputs(inputs, query) | |||
| if not moderation_result.flagged: | |||
| @@ -324,6 +333,76 @@ class Completion: | |||
| external_context = memory.load_memory_variables({}) | |||
| return external_context[memory_key] | |||
| @classmethod | |||
| def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask, | |||
| from_source: str) -> bool: | |||
| """Get memory messages.""" | |||
| app_model_config = conversation_message_task.app_model_config | |||
| app = conversation_message_task.app | |||
| annotation_reply = app_model_config.annotation_reply_dict | |||
| if annotation_reply['enabled']: | |||
| score_threshold = annotation_reply.get('score_threshold', 1) | |||
| embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name'] | |||
| embedding_model_name = annotation_reply['embedding_model']['embedding_model_name'] | |||
| # get embedding model | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=app.tenant_id, | |||
| model_provider_name=embedding_provider_name, | |||
| model_name=embedding_model_name | |||
| ) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| embedding_provider_name, | |||
| embedding_model_name, | |||
| 'annotation' | |||
| ) | |||
| dataset = Dataset( | |||
| id=app.id, | |||
| tenant_id=app.tenant_id, | |||
| indexing_technique='high_quality', | |||
| embedding_model_provider=embedding_provider_name, | |||
| embedding_model=embedding_model_name, | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| documents = vector_index.search( | |||
| conversation_message_task.query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': 1, | |||
| 'score_threshold': score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| } | |||
| ) | |||
| if documents: | |||
| annotation_id = documents[0].metadata['annotation_id'] | |||
| score = documents[0].metadata['score'] | |||
| annotation = AppAnnotationService.get_annotation_by_id(annotation_id) | |||
| if annotation: | |||
| conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name) | |||
| # insert annotation history | |||
| AppAnnotationService.add_annotation_history(annotation.id, | |||
| app.id, | |||
| annotation.question, | |||
| annotation.content, | |||
| conversation_message_task.query, | |||
| conversation_message_task.user.id, | |||
| conversation_message_task.message.id, | |||
| from_source, | |||
| score) | |||
| return True | |||
| return False | |||
| @classmethod | |||
| def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig, | |||
| conversation: Conversation, | |||
| @@ -319,6 +319,10 @@ class ConversationMessageTask: | |||
| self._pub_handler.pub_message_end(self.retriever_resource) | |||
| self._pub_handler.pub_end() | |||
| def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str): | |||
| self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at) | |||
| self._pub_handler.pub_end() | |||
| class PubHandler: | |||
| def __init__(self, user: Union[Account, EndUser], task_id: str, | |||
| @@ -435,7 +439,7 @@ class PubHandler: | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id | |||
| 'conversation_id': self._conversation.id, | |||
| } | |||
| } | |||
| if retriever_resource: | |||
| @@ -446,6 +450,30 @@ class PubHandler: | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float): | |||
| content = { | |||
| 'event': 'annotation', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id, | |||
| 'text': text, | |||
| 'annotation_id': annotation_id, | |||
| 'annotation_author_name': annotation_author_name | |||
| } | |||
| } | |||
| self._message.answer = text | |||
| self._message.provider_response_latency = time.perf_counter() - start_at | |||
| db.session.commit() | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_end(self): | |||
| content = { | |||
| 'event': 'end', | |||
| @@ -32,6 +32,10 @@ class BaseIndex(ABC): | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def delete_by_group_id(self, group_id: str) -> None: | |||
| raise NotImplementedError | |||
| @@ -107,6 +107,9 @@ class KeywordTableIndex(BaseIndex): | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| pass | |||
| def get_retriever(self, **kwargs: Any) -> BaseRetriever: | |||
| return KeywordTableRetriever(index=self, **kwargs) | |||
| @@ -121,6 +121,16 @@ class MilvusVectorIndex(BaseVectorIndex): | |||
| 'filter': f'id in {ids}' | |||
| }) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| ids = vector_store.get_ids_by_metadata_field(key, value) | |||
| if ids: | |||
| vector_store.del_texts({ | |||
| 'filter': f'id in {ids}' | |||
| }) | |||
| def delete_by_ids(self, doc_ids: list[str]) -> None: | |||
| vector_store = self._get_vector_store() | |||
| @@ -138,6 +138,22 @@ class QdrantVectorIndex(BaseVectorIndex): | |||
| ], | |||
| )) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| from qdrant_client.http import models | |||
| vector_store.del_texts(models.Filter( | |||
| must=[ | |||
| models.FieldCondition( | |||
| key=f"metadata.{key}", | |||
| match=models.MatchValue(value=value), | |||
| ), | |||
| ], | |||
| )) | |||
| def delete_by_ids(self, ids: list[str]) -> None: | |||
| vector_store = self._get_vector_store() | |||
| @@ -141,6 +141,17 @@ class WeaviateVectorIndex(BaseVectorIndex): | |||
| "valueText": document_id | |||
| }) | |||
| def delete_by_metadata_field(self, key: str, value: str): | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| vector_store.del_texts({ | |||
| "operator": "Equal", | |||
| "path": [key], | |||
| "valueText": value | |||
| }) | |||
| def delete_by_group_id(self, group_id: str): | |||
| if self._is_origin(): | |||
| self.recreate_dataset(self.dataset) | |||
| @@ -30,6 +30,16 @@ class MilvusVectorStore(Milvus): | |||
| else: | |||
| return None | |||
| def get_ids_by_metadata_field(self, key: str, value: str): | |||
| result = self.col.query( | |||
| expr=f'metadata["{key}"] == "{value}"', | |||
| output_fields=["id"] | |||
| ) | |||
| if result: | |||
| return [item["id"] for item in result] | |||
| else: | |||
| return None | |||
| def get_ids_by_doc_ids(self, doc_ids: list): | |||
| result = self.col.query( | |||
| expr=f'metadata["doc_id"] in {doc_ids}', | |||
| @@ -6,13 +6,13 @@ from models.model import AppModelConfig | |||
| @app_model_config_was_updated.connect | |||
| def handle(sender, **kwargs): | |||
| app_model = sender | |||
| app = sender | |||
| app_model_config = kwargs.get('app_model_config') | |||
| dataset_ids = get_dataset_ids_from_model_config(app_model_config) | |||
| app_dataset_joins = db.session.query(AppDatasetJoin).filter( | |||
| AppDatasetJoin.app_id == app_model.id | |||
| AppDatasetJoin.app_id == app.id | |||
| ).all() | |||
| removed_dataset_ids = [] | |||
| @@ -29,14 +29,14 @@ def handle(sender, **kwargs): | |||
| if removed_dataset_ids: | |||
| for dataset_id in removed_dataset_ids: | |||
| db.session.query(AppDatasetJoin).filter( | |||
| AppDatasetJoin.app_id == app_model.id, | |||
| AppDatasetJoin.app_id == app.id, | |||
| AppDatasetJoin.dataset_id == dataset_id | |||
| ).delete() | |||
| if added_dataset_ids: | |||
| for dataset_id in added_dataset_ids: | |||
| app_dataset_join = AppDatasetJoin( | |||
| app_id=app_model.id, | |||
| app_id=app.id, | |||
| dataset_id=dataset_id | |||
| ) | |||
| db.session.add(app_dataset_join) | |||
| @@ -0,0 +1,36 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| account_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'email': fields.String | |||
| } | |||
| annotation_fields = { | |||
| "id": fields.String, | |||
| "question": fields.String, | |||
| "answer": fields.Raw(attribute='content'), | |||
| "hit_count": fields.Integer, | |||
| "created_at": TimestampField, | |||
| # 'account': fields.Nested(account_fields, allow_null=True) | |||
| } | |||
| annotation_list_fields = { | |||
| "data": fields.List(fields.Nested(annotation_fields)), | |||
| } | |||
| annotation_hit_history_fields = { | |||
| "id": fields.String, | |||
| "source": fields.String, | |||
| "score": fields.Float, | |||
| "question": fields.String, | |||
| "created_at": TimestampField, | |||
| "match": fields.String(attribute='annotation_question'), | |||
| "response": fields.String(attribute='annotation_content') | |||
| } | |||
| annotation_hit_history_list_fields = { | |||
| "data": fields.List(fields.Nested(annotation_hit_history_fields)), | |||
| } | |||
| @@ -21,6 +21,7 @@ model_config_fields = { | |||
| 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), | |||
| 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), | |||
| 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), | |||
| 'annotation_reply': fields.Raw(attribute='annotation_reply_dict'), | |||
| 'more_like_this': fields.Raw(attribute='more_like_this_dict'), | |||
| 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), | |||
| 'external_data_tools': fields.Raw(attribute='external_data_tools_list'), | |||
| @@ -23,11 +23,18 @@ feedback_fields = { | |||
| } | |||
| annotation_fields = { | |||
| 'id': fields.String, | |||
| 'question': fields.String, | |||
| 'content': fields.String, | |||
| 'account': fields.Nested(account_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| annotation_hit_history_fields = { | |||
| 'annotation_id': fields.String, | |||
| 'annotation_create_account': fields.Nested(account_fields, allow_null=True) | |||
| } | |||
| message_file_fields = { | |||
| 'id': fields.String, | |||
| 'type': fields.String, | |||
| @@ -49,6 +56,7 @@ message_detail_fields = { | |||
| 'from_account_id': fields.String, | |||
| 'feedbacks': fields.List(fields.Nested(feedback_fields)), | |||
| 'annotation': fields.Nested(annotation_fields, allow_null=True), | |||
| 'annotation_hit_history': fields.Nested(annotation_hit_history_fields, allow_null=True), | |||
| 'created_at': TimestampField, | |||
| 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| """add_app_anntation_setting | |||
| Revision ID: 246ba09cbbdb | |||
| Revises: 714aafe25d39 | |||
| Create Date: 2023-12-14 11:26:12.287264 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '246ba09cbbdb' | |||
| down_revision = '714aafe25d39' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('app_annotation_settings', | |||
| sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('app_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('score_threshold', sa.Float(), server_default=sa.text('0'), nullable=False), | |||
| sa.Column('collection_binding_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('created_user_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('updated_user_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey') | |||
| ) | |||
| with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: | |||
| batch_op.create_index('app_annotation_settings_app_idx', ['app_id'], unique=False) | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.drop_column('annotation_reply') | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('annotation_reply', sa.TEXT(), autoincrement=False, nullable=True)) | |||
| with op.batch_alter_table('app_annotation_settings', schema=None) as batch_op: | |||
| batch_op.drop_index('app_annotation_settings_app_idx') | |||
| op.drop_table('app_annotation_settings') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,32 @@ | |||
| """add-annotation-histoiry-score | |||
| Revision ID: 46976cc39132 | |||
| Revises: e1901f623fd0 | |||
| Create Date: 2023-12-13 04:39:59.302971 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = '46976cc39132' | |||
| down_revision = 'e1901f623fd0' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('score', sa.Float(), server_default=sa.text('0'), nullable=False)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.drop_column('score') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,34 @@ | |||
| """add_anntation_history_match_response | |||
| Revision ID: 714aafe25d39 | |||
| Revises: f2a6fc85e260 | |||
| Create Date: 2023-12-14 06:38:02.972527 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = '714aafe25d39' | |||
| down_revision = 'f2a6fc85e260' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('annotation_question', sa.Text(), nullable=False)) | |||
| batch_op.add_column(sa.Column('annotation_content', sa.Text(), nullable=False)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.drop_column('annotation_content') | |||
| batch_op.drop_column('annotation_question') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,79 @@ | |||
| """add-annotation-reply | |||
| Revision ID: e1901f623fd0 | |||
| Revises: fca025d3b60f | |||
| Create Date: 2023-12-12 06:58:41.054544 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'e1901f623fd0' | |||
| down_revision = 'fca025d3b60f' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('app_annotation_hit_histories', | |||
| sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('app_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('annotation_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('source', sa.Text(), nullable=False), | |||
| sa.Column('question', sa.Text(), nullable=False), | |||
| sa.Column('account_id', postgresql.UUID(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey') | |||
| ) | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.create_index('app_annotation_hit_histories_account_idx', ['account_id'], unique=False) | |||
| batch_op.create_index('app_annotation_hit_histories_annotation_idx', ['annotation_id'], unique=False) | |||
| batch_op.create_index('app_annotation_hit_histories_app_idx', ['app_id'], unique=False) | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('annotation_reply', sa.Text(), nullable=True)) | |||
| with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('type', sa.String(length=40), server_default=sa.text("'dataset'::character varying"), nullable=False)) | |||
| with op.batch_alter_table('message_annotations', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('question', sa.Text(), nullable=True)) | |||
| batch_op.add_column(sa.Column('hit_count', sa.Integer(), server_default=sa.text('0'), nullable=False)) | |||
| batch_op.alter_column('conversation_id', | |||
| existing_type=postgresql.UUID(), | |||
| nullable=True) | |||
| batch_op.alter_column('message_id', | |||
| existing_type=postgresql.UUID(), | |||
| nullable=True) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('message_annotations', schema=None) as batch_op: | |||
| batch_op.alter_column('message_id', | |||
| existing_type=postgresql.UUID(), | |||
| nullable=False) | |||
| batch_op.alter_column('conversation_id', | |||
| existing_type=postgresql.UUID(), | |||
| nullable=False) | |||
| batch_op.drop_column('hit_count') | |||
| batch_op.drop_column('question') | |||
| with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op: | |||
| batch_op.drop_column('type') | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.drop_column('annotation_reply') | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.drop_index('app_annotation_hit_histories_app_idx') | |||
| batch_op.drop_index('app_annotation_hit_histories_annotation_idx') | |||
| batch_op.drop_index('app_annotation_hit_histories_account_idx') | |||
| op.drop_table('app_annotation_hit_histories') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,34 @@ | |||
| """add_anntation_history_message_id | |||
| Revision ID: f2a6fc85e260 | |||
| Revises: 46976cc39132 | |||
| Create Date: 2023-12-13 11:09:29.329584 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'f2a6fc85e260' | |||
| down_revision = '46976cc39132' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('message_id', postgresql.UUID(), nullable=False)) | |||
| batch_op.create_index('app_annotation_hit_histories_message_idx', ['message_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_annotation_hit_histories', schema=None) as batch_op: | |||
| batch_op.drop_index('app_annotation_hit_histories_message_idx') | |||
| batch_op.drop_column('message_id') | |||
| # ### end Alembic commands ### | |||
| @@ -475,5 +475,6 @@ class DatasetCollectionBinding(db.Model): | |||
| id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) | |||
| provider_name = db.Column(db.String(40), nullable=False) | |||
| model_name = db.Column(db.String(40), nullable=False) | |||
| type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) | |||
| collection_name = db.Column(db.String(64), nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| @@ -2,6 +2,7 @@ import json | |||
| from flask import current_app, request | |||
| from flask_login import UserMixin | |||
| from sqlalchemy import Float | |||
| from sqlalchemy.dialects.postgresql import UUID | |||
| from core.file.upload_file_parser import UploadFileParser | |||
| @@ -128,6 +129,25 @@ class AppModelConfig(db.Model): | |||
| return json.loads(self.retriever_resource) if self.retriever_resource \ | |||
| else {"enabled": False} | |||
| @property | |||
| def annotation_reply_dict(self) -> dict: | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == self.app_id).first() | |||
| if annotation_setting: | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name | |||
| } | |||
| } | |||
| else: | |||
| return {"enabled": False} | |||
| @property | |||
| def more_like_this_dict(self) -> dict: | |||
| return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} | |||
| @@ -170,7 +190,9 @@ class AppModelConfig(db.Model): | |||
| @property | |||
| def file_upload_dict(self) -> dict: | |||
| return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}} | |||
| return json.loads(self.file_upload) if self.file_upload else { | |||
| "image": {"enabled": False, "number_limits": 3, "detail": "high", | |||
| "transfer_methods": ["remote_url", "local_file"]}} | |||
| def to_dict(self) -> dict: | |||
| return { | |||
| @@ -182,6 +204,7 @@ class AppModelConfig(db.Model): | |||
| "suggested_questions_after_answer": self.suggested_questions_after_answer_dict, | |||
| "speech_to_text": self.speech_to_text_dict, | |||
| "retriever_resource": self.retriever_resource_dict, | |||
| "annotation_reply": self.annotation_reply_dict, | |||
| "more_like_this": self.more_like_this_dict, | |||
| "sensitive_word_avoidance": self.sensitive_word_avoidance_dict, | |||
| "external_data_tools": self.external_data_tools_list, | |||
| @@ -504,6 +527,12 @@ class Message(db.Model): | |||
| annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() | |||
| return annotation | |||
| @property | |||
| def annotation_hit_history(self): | |||
| annotation_history = (db.session.query(AppAnnotationHitHistory) | |||
| .filter(AppAnnotationHitHistory.message_id == self.id).first()) | |||
| return annotation_history | |||
| @property | |||
| def app_model_config(self): | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() | |||
| @@ -616,9 +645,11 @@ class MessageAnnotation(db.Model): | |||
| id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) | |||
| app_id = db.Column(UUID, nullable=False) | |||
| conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) | |||
| message_id = db.Column(UUID, nullable=False) | |||
| conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=True) | |||
| message_id = db.Column(UUID, nullable=True) | |||
| question = db.Column(db.Text, nullable=True) | |||
| content = db.Column(db.Text, nullable=False) | |||
| hit_count = db.Column(db.Integer, nullable=False, server_default=db.text('0')) | |||
| account_id = db.Column(UUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| @@ -629,6 +660,79 @@ class MessageAnnotation(db.Model): | |||
| return account | |||
| class AppAnnotationHitHistory(db.Model): | |||
| __tablename__ = 'app_annotation_hit_histories' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='app_annotation_hit_histories_pkey'), | |||
| db.Index('app_annotation_hit_histories_app_idx', 'app_id'), | |||
| db.Index('app_annotation_hit_histories_account_idx', 'account_id'), | |||
| db.Index('app_annotation_hit_histories_annotation_idx', 'annotation_id'), | |||
| db.Index('app_annotation_hit_histories_message_idx', 'message_id'), | |||
| ) | |||
| id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) | |||
| app_id = db.Column(UUID, nullable=False) | |||
| annotation_id = db.Column(UUID, nullable=False) | |||
| source = db.Column(db.Text, nullable=False) | |||
| question = db.Column(db.Text, nullable=False) | |||
| account_id = db.Column(UUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| score = db.Column(Float, nullable=False, server_default=db.text('0')) | |||
| message_id = db.Column(UUID, nullable=False) | |||
| annotation_question = db.Column(db.Text, nullable=False) | |||
| annotation_content = db.Column(db.Text, nullable=False) | |||
| @property | |||
| def account(self): | |||
| account = (db.session.query(Account) | |||
| .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) | |||
| .filter(MessageAnnotation.id == self.annotation_id).first()) | |||
| return account | |||
| @property | |||
| def annotation_create_account(self): | |||
| account = db.session.query(Account).filter(Account.id == self.account_id).first() | |||
| return account | |||
| class AppAnnotationSetting(db.Model): | |||
| __tablename__ = 'app_annotation_settings' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='app_annotation_settings_pkey'), | |||
| db.Index('app_annotation_settings_app_idx', 'app_id') | |||
| ) | |||
| id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) | |||
| app_id = db.Column(UUID, nullable=False) | |||
| score_threshold = db.Column(Float, nullable=False, server_default=db.text('0')) | |||
| collection_binding_id = db.Column(UUID, nullable=False) | |||
| created_user_id = db.Column(UUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_user_id = db.Column(UUID, nullable=False) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| @property | |||
| def created_account(self): | |||
| account = (db.session.query(Account) | |||
| .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) | |||
| .filter(AppAnnotationSetting.id == self.annotation_id).first()) | |||
| return account | |||
| @property | |||
| def updated_account(self): | |||
| account = (db.session.query(Account) | |||
| .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) | |||
| .filter(AppAnnotationSetting.id == self.annotation_id).first()) | |||
| return account | |||
| @property | |||
| def collection_binding_detail(self): | |||
| from .dataset import DatasetCollectionBinding | |||
| collection_binding_detail = (db.session.query(DatasetCollectionBinding) | |||
| .filter(DatasetCollectionBinding.id == self.collection_binding_id).first()) | |||
| return collection_binding_detail | |||
| class OperationLog(db.Model): | |||
| __tablename__ = 'operation_logs' | |||
| __table_args__ = ( | |||
| @@ -0,0 +1,426 @@ | |||
| import datetime | |||
| import json | |||
| import uuid | |||
| import pandas as pd | |||
| from flask_login import current_user | |||
| from sqlalchemy import or_ | |||
| from werkzeug.datastructures import FileStorage | |||
| from werkzeug.exceptions import NotFound | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.model import MessageAnnotation, Message, App, AppAnnotationHitHistory, AppAnnotationSetting | |||
| from tasks.annotation.add_annotation_to_index_task import add_annotation_to_index_task | |||
| from tasks.annotation.enable_annotation_reply_task import enable_annotation_reply_task | |||
| from tasks.annotation.disable_annotation_reply_task import disable_annotation_reply_task | |||
| from tasks.annotation.update_annotation_to_index_task import update_annotation_to_index_task | |||
| from tasks.annotation.delete_annotation_index_task import delete_annotation_index_task | |||
| from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task | |||
| class AppAnnotationService: | |||
| @classmethod | |||
| def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| if 'message_id' in args and args['message_id']: | |||
| message_id = str(args['message_id']) | |||
| # get message info | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app.id | |||
| ).first() | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| annotation = message.annotation | |||
| # save the message annotation | |||
| if annotation: | |||
| annotation.content = args['answer'] | |||
| annotation.question = args['question'] | |||
| else: | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| conversation_id=message.conversation_id, | |||
| message_id=message.id, | |||
| content=args['answer'], | |||
| question=args['question'], | |||
| account_id=current_user.id | |||
| ) | |||
| else: | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| content=args['answer'], | |||
| question=args['question'], | |||
| account_id=current_user.id | |||
| ) | |||
| db.session.add(annotation) | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| if annotation_setting: | |||
| add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, | |||
| app_id, annotation_setting.collection_binding_id) | |||
| return annotation | |||
| @classmethod | |||
| def enable_app_annotation(cls, args: dict, app_id: str) -> dict: | |||
| enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) | |||
| cache_result = redis_client.get(enable_app_annotation_key) | |||
| if cache_result is not None: | |||
| return { | |||
| 'job_id': cache_result, | |||
| 'job_status': 'processing' | |||
| } | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(enable_app_annotation_job_key, 'waiting') | |||
| enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, | |||
| args['score_threshold'], | |||
| args['embedding_provider_name'], args['embedding_model_name']) | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| } | |||
| @classmethod | |||
| def disable_app_annotation(cls, app_id: str) -> dict: | |||
| disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) | |||
| cache_result = redis_client.get(disable_app_annotation_key) | |||
| if cache_result is not None: | |||
| return { | |||
| 'job_id': cache_result, | |||
| 'job_status': 'processing' | |||
| } | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(disable_app_annotation_job_key, 'waiting') | |||
| disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| } | |||
| @classmethod | |||
| def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| if keyword: | |||
| annotations = (db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .filter( | |||
| or_( | |||
| MessageAnnotation.question.ilike('%{}%'.format(keyword)), | |||
| MessageAnnotation.content.ilike('%{}%'.format(keyword)) | |||
| ) | |||
| ) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) | |||
| else: | |||
| annotations = (db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) | |||
| return annotations.items, annotations.total | |||
| @classmethod | |||
| def export_annotation_list_by_app_id(cls, app_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotations = (db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc()).all()) | |||
| return annotations | |||
| @classmethod | |||
| def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| content=args['answer'], | |||
| question=args['question'], | |||
| account_id=current_user.id | |||
| ) | |||
| db.session.add(annotation) | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| if annotation_setting: | |||
| add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, | |||
| app_id, annotation_setting.collection_binding_id) | |||
| return annotation | |||
| @classmethod | |||
| def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() | |||
| if not annotation: | |||
| raise NotFound("Annotation not found") | |||
| annotation.content = args['answer'] | |||
| annotation.question = args['question'] | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| app_annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id | |||
| ).first() | |||
| if app_annotation_setting: | |||
| update_annotation_to_index_task.delay(annotation.id, annotation.question, | |||
| current_user.current_tenant_id, | |||
| app_id, app_annotation_setting.collection_binding_id) | |||
| return annotation | |||
| @classmethod | |||
| def delete_app_annotation(cls, app_id: str, annotation_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() | |||
| if not annotation: | |||
| raise NotFound("Annotation not found") | |||
| db.session.delete(annotation) | |||
| annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) | |||
| .filter(AppAnnotationHitHistory.annotation_id == annotation_id) | |||
| .all() | |||
| ) | |||
| if annotation_hit_histories: | |||
| for annotation_hit_history in annotation_hit_histories: | |||
| db.session.delete(annotation_hit_history) | |||
| db.session.commit() | |||
| # if annotation reply is enabled , delete annotation index | |||
| app_annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id | |||
| ).first() | |||
| if app_annotation_setting: | |||
| delete_annotation_index_task.delay(annotation.id, app_id, | |||
| current_user.current_tenant_id, | |||
| app_annotation_setting.collection_binding_id) | |||
| @classmethod | |||
| def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| try: | |||
| # Skip the first row | |||
| df = pd.read_csv(file) | |||
| result = [] | |||
| for index, row in df.iterrows(): | |||
| content = { | |||
| 'question': row[0], | |||
| 'answer': row[1] | |||
| } | |||
| result.append(content) | |||
| if len(result) == 0: | |||
| raise ValueError("The CSV file is empty.") | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(indexing_cache_key, 'waiting') | |||
| batch_import_annotations_task.delay(str(job_id), result, app_id, | |||
| current_user.current_tenant_id, current_user.id) | |||
| except Exception as e: | |||
| return { | |||
| 'error_msg': str(e) | |||
| } | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| } | |||
| @classmethod | |||
| def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() | |||
| if not annotation: | |||
| raise NotFound("Annotation not found") | |||
| annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) | |||
| .filter(AppAnnotationHitHistory.app_id == app_id, | |||
| AppAnnotationHitHistory.annotation_id == annotation_id, | |||
| ) | |||
| .order_by(AppAnnotationHitHistory.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) | |||
| return annotation_hit_histories.items, annotation_hit_histories.total | |||
| @classmethod | |||
| def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None: | |||
| annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first() | |||
| if not annotation: | |||
| return None | |||
| return annotation | |||
| @classmethod | |||
| def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, | |||
| annotation_content: str, query: str, user_id: str, | |||
| message_id: str, from_source: str, score: float): | |||
| # add hit count to annotation | |||
| db.session.query(MessageAnnotation).filter( | |||
| MessageAnnotation.id == annotation_id | |||
| ).update( | |||
| {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, | |||
| synchronize_session=False | |||
| ) | |||
| annotation_hit_history = AppAnnotationHitHistory( | |||
| annotation_id=annotation_id, | |||
| app_id=app_id, | |||
| account_id=user_id, | |||
| question=query, | |||
| source=from_source, | |||
| score=score, | |||
| message_id=message_id, | |||
| annotation_question=annotation_question, | |||
| annotation_content=annotation_content | |||
| ) | |||
| db.session.add(annotation_hit_history) | |||
| db.session.commit() | |||
| @classmethod | |||
| def get_app_annotation_setting_by_app_id(cls, app_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| if annotation_setting: | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name | |||
| } | |||
| } | |||
| return { | |||
| "enabled": False | |||
| } | |||
| @classmethod | |||
| def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id, | |||
| AppAnnotationSetting.id == annotation_setting_id, | |||
| ).first() | |||
| if not annotation_setting: | |||
| raise NotFound("App annotation not found") | |||
| annotation_setting.score_threshold = args['score_threshold'] | |||
| annotation_setting.updated_user_id = current_user.id | |||
| annotation_setting.updated_at = datetime.datetime.utcnow() | |||
| db.session.add(annotation_setting) | |||
| db.session.commit() | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| return { | |||
| "id": annotation_setting.id, | |||
| "enabled": True, | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name | |||
| } | |||
| } | |||
| @@ -138,7 +138,22 @@ class AppModelConfigService: | |||
| config["retriever_resource"]["enabled"] = False | |||
| if not isinstance(config["retriever_resource"]["enabled"], bool): | |||
| raise ValueError("enabled in speech_to_text must be of boolean type") | |||
| raise ValueError("enabled in retriever_resource must be of boolean type") | |||
| # annotation reply | |||
| if 'annotation_reply' not in config or not config["annotation_reply"]: | |||
| config["annotation_reply"] = { | |||
| "enabled": False | |||
| } | |||
| if not isinstance(config["annotation_reply"], dict): | |||
| raise ValueError("annotation_reply must be of dict type") | |||
| if "enabled" not in config["annotation_reply"] or not config["annotation_reply"]["enabled"]: | |||
| config["annotation_reply"]["enabled"] = False | |||
| if not isinstance(config["annotation_reply"]["enabled"], bool): | |||
| raise ValueError("enabled in annotation_reply must be of boolean type") | |||
| # more_like_this | |||
| if 'more_like_this' not in config or not config["more_like_this"]: | |||
| @@ -325,6 +340,7 @@ class AppModelConfigService: | |||
| "suggested_questions_after_answer": config["suggested_questions_after_answer"], | |||
| "speech_to_text": config["speech_to_text"], | |||
| "retriever_resource": config["retriever_resource"], | |||
| "annotation_reply": config["annotation_reply"], | |||
| "more_like_this": config["more_like_this"], | |||
| "sensitive_word_avoidance": config["sensitive_word_avoidance"], | |||
| "external_data_tools": config["external_data_tools"], | |||
| @@ -165,7 +165,8 @@ class CompletionService: | |||
| 'streaming': streaming, | |||
| 'is_model_config_override': is_model_config_override, | |||
| 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev', | |||
| 'auto_generate_name': auto_generate_name | |||
| 'auto_generate_name': auto_generate_name, | |||
| 'from_source': from_source | |||
| }) | |||
| generate_worker_thread.start() | |||
| @@ -193,7 +194,7 @@ class CompletionService: | |||
| query: str, inputs: dict, files: List[PromptMessageFile], | |||
| detached_user: Union[Account, EndUser], | |||
| detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool, | |||
| retriever_from: str = 'dev', auto_generate_name: bool = True): | |||
| retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'): | |||
| with flask_app.app_context(): | |||
| # fixed the state of the model object when it detached from the original session | |||
| user = db.session.merge(detached_user) | |||
| @@ -218,7 +219,8 @@ class CompletionService: | |||
| streaming=streaming, | |||
| is_override=is_model_config_override, | |||
| retriever_from=retriever_from, | |||
| auto_generate_name=auto_generate_name | |||
| auto_generate_name=auto_generate_name, | |||
| from_source=from_source | |||
| ) | |||
| except (ConversationTaskInterruptException, ConversationTaskStoppedException): | |||
| pass | |||
| @@ -385,6 +387,9 @@ class CompletionService: | |||
| result = json.loads(result) | |||
| if result.get('error'): | |||
| cls.handle_error(result) | |||
| if result['event'] == 'annotation' and 'data' in result: | |||
| message_result['annotation'] = result.get('data') | |||
| return cls.get_blocking_annotation_message_response_data(message_result) | |||
| if result['event'] == 'message' and 'data' in result: | |||
| message_result['message'] = result.get('data') | |||
| if result['event'] == 'message_end' and 'data' in result: | |||
| @@ -427,6 +432,9 @@ class CompletionService: | |||
| elif event == 'agent_thought': | |||
| yield "data: " + json.dumps( | |||
| cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" | |||
| elif event == 'annotation': | |||
| yield "data: " + json.dumps( | |||
| cls.get_annotation_response_data(result.get('data'))) + "\n\n" | |||
| elif event == 'message_end': | |||
| yield "data: " + json.dumps( | |||
| cls.get_message_end_data(result.get('data'))) + "\n\n" | |||
| @@ -499,6 +507,25 @@ class CompletionService: | |||
| return response_data | |||
| @classmethod | |||
| def get_blocking_annotation_message_response_data(cls, data: dict): | |||
| message = data.get('annotation') | |||
| response_data = { | |||
| 'event': 'annotation', | |||
| 'task_id': message.get('task_id'), | |||
| 'id': message.get('message_id'), | |||
| 'answer': message.get('text'), | |||
| 'metadata': {}, | |||
| 'created_at': int(time.time()), | |||
| 'annotation_id': message.get('annotation_id'), | |||
| 'annotation_author_name': message.get('annotation_author_name') | |||
| } | |||
| if message.get('mode') == 'chat': | |||
| response_data['conversation_id'] = message.get('conversation_id') | |||
| return response_data | |||
| @classmethod | |||
| def get_message_end_data(cls, data: dict): | |||
| response_data = { | |||
| @@ -551,6 +578,23 @@ class CompletionService: | |||
| return response_data | |||
| @classmethod | |||
| def get_annotation_response_data(cls, data: dict): | |||
| response_data = { | |||
| 'event': 'annotation', | |||
| 'task_id': data.get('task_id'), | |||
| 'id': data.get('message_id'), | |||
| 'answer': data.get('text'), | |||
| 'created_at': int(time.time()), | |||
| 'annotation_id': data.get('annotation_id'), | |||
| 'annotation_author_name': data.get('annotation_author_name'), | |||
| } | |||
| if data.get('mode') == 'chat': | |||
| response_data['conversation_id'] = data.get('conversation_id') | |||
| return response_data | |||
| @classmethod | |||
| def handle_error(cls, result: dict): | |||
| logging.debug("error: %s", result) | |||
| @@ -33,10 +33,7 @@ from tasks.clean_notion_document_task import clean_notion_document_task | |||
| from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task | |||
| from tasks.document_indexing_task import document_indexing_task | |||
| from tasks.document_indexing_update_task import document_indexing_update_task | |||
| from tasks.create_segment_to_index_task import create_segment_to_index_task | |||
| from tasks.update_segment_index_task import update_segment_index_task | |||
| from tasks.recover_document_indexing_task import recover_document_indexing_task | |||
| from tasks.update_segment_keyword_index_task import update_segment_keyword_index_task | |||
| from tasks.delete_segment_from_index_task import delete_segment_from_index_task | |||
| @@ -1175,10 +1172,12 @@ class SegmentService: | |||
| class DatasetCollectionBindingService: | |||
| @classmethod | |||
| def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding: | |||
| def get_dataset_collection_binding(cls, provider_name: str, model_name: str, | |||
| collection_type: str = 'dataset') -> DatasetCollectionBinding: | |||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||
| filter(DatasetCollectionBinding.provider_name == provider_name, | |||
| DatasetCollectionBinding.model_name == model_name). \ | |||
| DatasetCollectionBinding.model_name == model_name, | |||
| DatasetCollectionBinding.type == collection_type). \ | |||
| order_by(DatasetCollectionBinding.created_at). \ | |||
| first() | |||
| @@ -1186,8 +1185,20 @@ class DatasetCollectionBindingService: | |||
| dataset_collection_binding = DatasetCollectionBinding( | |||
| provider_name=provider_name, | |||
| model_name=model_name, | |||
| collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node' | |||
| collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node', | |||
| type=collection_type | |||
| ) | |||
| db.session.add(dataset_collection_binding) | |||
| db.session.flush() | |||
| db.session.commit() | |||
| return dataset_collection_binding | |||
| @classmethod | |||
| def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str, | |||
| collection_type: str = 'dataset') -> DatasetCollectionBinding: | |||
| dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ | |||
| filter(DatasetCollectionBinding.id == collection_binding_id, | |||
| DatasetCollectionBinding.type == collection_type). \ | |||
| order_by(DatasetCollectionBinding.created_at). \ | |||
| first() | |||
| return dataset_collection_binding | |||
| @@ -0,0 +1,59 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from langchain.schema import Document | |||
| from core.index.index import IndexBuilder | |||
| from models.dataset import Dataset | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| @shared_task(queue='dataset') | |||
| def add_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, | |||
| collection_binding_id: str): | |||
| """ | |||
| Add annotation to index. | |||
| :param annotation_id: annotation id | |||
| :param question: question | |||
| :param tenant_id: tenant id | |||
| :param app_id: app id | |||
| :param collection_binding_id: embedding binding id | |||
| Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) | |||
| """ | |||
| logging.info(click.style('Start build index for annotation: {}'.format(annotation_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| try: | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( | |||
| collection_binding_id, | |||
| 'annotation' | |||
| ) | |||
| dataset = Dataset( | |||
| id=app_id, | |||
| tenant_id=tenant_id, | |||
| indexing_technique='high_quality', | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| document = Document( | |||
| page_content=question, | |||
| metadata={ | |||
| "annotation_id": annotation_id, | |||
| "app_id": app_id, | |||
| "doc_id": annotation_id | |||
| } | |||
| ) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts([document]) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style( | |||
| 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), | |||
| fg='green')) | |||
| except Exception: | |||
| logging.exception("Build index for annotation failed") | |||
| @@ -0,0 +1,99 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from langchain.schema import Document | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| from models.model import MessageAnnotation, App, AppAnnotationSetting | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| @shared_task(queue='dataset') | |||
| def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: str, tenant_id: str, | |||
| user_id: str): | |||
| """ | |||
| Add annotation to index. | |||
| :param job_id: job_id | |||
| :param content_list: content list | |||
| :param tenant_id: tenant id | |||
| :param app_id: app id | |||
| :param user_id: user_id | |||
| """ | |||
| logging.info(click.style('Start batch import annotation: {}'.format(job_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if app: | |||
| try: | |||
| documents = [] | |||
| for content in content_list: | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| content=content['answer'], | |||
| question=content['question'], | |||
| account_id=user_id | |||
| ) | |||
| db.session.add(annotation) | |||
| db.session.flush() | |||
| document = Document( | |||
| page_content=content['question'], | |||
| metadata={ | |||
| "annotation_id": annotation.id, | |||
| "app_id": app_id, | |||
| "doc_id": annotation.id | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| # if annotation reply is enabled , batch add annotations' index | |||
| app_annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id | |||
| ).first() | |||
| if app_annotation_setting: | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( | |||
| app_annotation_setting.collection_binding_id, | |||
| 'annotation' | |||
| ) | |||
| if not dataset_collection_binding: | |||
| raise NotFound("App annotation setting not found") | |||
| dataset = Dataset( | |||
| id=app_id, | |||
| tenant_id=tenant_id, | |||
| indexing_technique='high_quality', | |||
| embedding_model_provider=dataset_collection_binding.provider_name, | |||
| embedding_model=dataset_collection_binding.model_name, | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts(documents) | |||
| db.session.commit() | |||
| redis_client.setex(indexing_cache_key, 600, 'completed') | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style( | |||
| 'Build index successful for batch import annotation: {} latency: {}'.format(job_id, end_at - start_at), | |||
| fg='green')) | |||
| except Exception as e: | |||
| db.session.rollback() | |||
| redis_client.setex(indexing_cache_key, 600, 'error') | |||
| indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) | |||
| redis_client.setex(indexing_error_msg_key, 600, str(e)) | |||
| logging.exception("Build index for batch import annotations failed") | |||
| @@ -0,0 +1,45 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.index.index import IndexBuilder | |||
| from models.dataset import Dataset | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| @shared_task(queue='dataset') | |||
| def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str, | |||
| collection_binding_id: str): | |||
| """ | |||
| Async delete annotation index task | |||
| """ | |||
| logging.info(click.style('Start delete app annotation index: {}'.format(app_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| try: | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( | |||
| collection_binding_id, | |||
| 'annotation' | |||
| ) | |||
| dataset = Dataset( | |||
| id=app_id, | |||
| tenant_id=tenant_id, | |||
| indexing_technique='high_quality', | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| vector_index = IndexBuilder.get_default_high_quality_index(dataset) | |||
| if vector_index: | |||
| try: | |||
| vector_index.delete_by_metadata_field('annotation_id', annotation_id) | |||
| except Exception: | |||
| logging.exception("Delete annotation index failed when annotation deleted.") | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), | |||
| fg='green')) | |||
| except Exception as e: | |||
| logging.exception("Annotation deleted index failed:{}".format(str(e))) | |||
| @@ -0,0 +1,74 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| from models.model import MessageAnnotation, App, AppAnnotationSetting | |||
| @shared_task(queue='dataset') | |||
| def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): | |||
| """ | |||
| Async enable annotation reply task | |||
| """ | |||
| logging.info(click.style('Start delete app annotations index: {}'.format(app_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| app_annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id | |||
| ).first() | |||
| if not app_annotation_setting: | |||
| raise NotFound("App annotation setting not found") | |||
| disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) | |||
| disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) | |||
| try: | |||
| dataset = Dataset( | |||
| id=app_id, | |||
| tenant_id=tenant_id, | |||
| indexing_technique='high_quality', | |||
| collection_binding_id=app_annotation_setting.collection_binding_id | |||
| ) | |||
| vector_index = IndexBuilder.get_default_high_quality_index(dataset) | |||
| if vector_index: | |||
| try: | |||
| vector_index.delete_by_metadata_field('app_id', app_id) | |||
| except Exception: | |||
| logging.exception("Delete doc index failed when dataset deleted.") | |||
| redis_client.setex(disable_app_annotation_job_key, 600, 'completed') | |||
| # delete annotation setting | |||
| db.session.delete(app_annotation_setting) | |||
| db.session.commit() | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style('App annotations index deleted : {} latency: {}'.format(app_id, end_at - start_at), | |||
| fg='green')) | |||
| except Exception as e: | |||
| logging.exception("Annotation batch deleted index failed:{}".format(str(e))) | |||
| redis_client.setex(disable_app_annotation_job_key, 600, 'error') | |||
| disable_app_annotation_error_key = 'disable_app_annotation_error_{}'.format(str(job_id)) | |||
| redis_client.setex(disable_app_annotation_error_key, 600, str(e)) | |||
| finally: | |||
| redis_client.delete(disable_app_annotation_key) | |||
| @@ -0,0 +1,106 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from langchain.schema import Document | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.index import IndexBuilder | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| from models.model import MessageAnnotation, App, AppAnnotationSetting | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| @shared_task(queue='dataset') | |||
| def enable_annotation_reply_task(job_id: str, app_id: str, user_id: str, tenant_id: str, score_threshold: float, | |||
| embedding_provider_name: str, embedding_model_name: str): | |||
| """ | |||
| Async enable annotation reply task | |||
| """ | |||
| logging.info(click.style('Start add app annotation to index: {}'.format(app_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all() | |||
| enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) | |||
| enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) | |||
| try: | |||
| documents = [] | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| embedding_provider_name, | |||
| embedding_model_name, | |||
| 'annotation' | |||
| ) | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| if annotation_setting: | |||
| annotation_setting.score_threshold = score_threshold | |||
| annotation_setting.collection_binding_id = dataset_collection_binding.id | |||
| annotation_setting.updated_user_id = user_id | |||
| annotation_setting.updated_at = datetime.datetime.utcnow() | |||
| db.session.add(annotation_setting) | |||
| else: | |||
| new_app_annotation_setting = AppAnnotationSetting( | |||
| app_id=app_id, | |||
| score_threshold=score_threshold, | |||
| collection_binding_id=dataset_collection_binding.id, | |||
| created_user_id=user_id, | |||
| updated_user_id=user_id | |||
| ) | |||
| db.session.add(new_app_annotation_setting) | |||
| dataset = Dataset( | |||
| id=app_id, | |||
| tenant_id=tenant_id, | |||
| indexing_technique='high_quality', | |||
| embedding_model_provider=embedding_provider_name, | |||
| embedding_model=embedding_model_name, | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| if annotations: | |||
| for annotation in annotations: | |||
| document = Document( | |||
| page_content=annotation.question, | |||
| metadata={ | |||
| "annotation_id": annotation.id, | |||
| "app_id": app_id, | |||
| "doc_id": annotation.id | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| try: | |||
| index.delete_by_metadata_field('app_id', app_id) | |||
| except Exception as e: | |||
| logging.info( | |||
| click.style('Delete annotation index error: {}'.format(str(e)), | |||
| fg='red')) | |||
| index.add_texts(documents) | |||
| db.session.commit() | |||
| redis_client.setex(enable_app_annotation_job_key, 600, 'completed') | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style('App annotations added to index: {} latency: {}'.format(app_id, end_at - start_at), | |||
| fg='green')) | |||
| except Exception as e: | |||
| logging.exception("Annotation batch created index failed:{}".format(str(e))) | |||
| redis_client.setex(enable_app_annotation_job_key, 600, 'error') | |||
| enable_app_annotation_error_key = 'enable_app_annotation_error_{}'.format(str(job_id)) | |||
| redis_client.setex(enable_app_annotation_error_key, 600, str(e)) | |||
| db.session.rollback() | |||
| finally: | |||
| redis_client.delete(enable_app_annotation_key) | |||
| @@ -0,0 +1,63 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from langchain.schema import Document | |||
| from core.index.index import IndexBuilder | |||
| from models.dataset import Dataset | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| @shared_task(queue='dataset') | |||
| def update_annotation_to_index_task(annotation_id: str, question: str, tenant_id: str, app_id: str, | |||
| collection_binding_id: str): | |||
| """ | |||
| Update annotation to index. | |||
| :param annotation_id: annotation id | |||
| :param question: question | |||
| :param tenant_id: tenant id | |||
| :param app_id: app id | |||
| :param collection_binding_id: embedding binding id | |||
| Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct) | |||
| """ | |||
| logging.info(click.style('Start update index for annotation: {}'.format(annotation_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| try: | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( | |||
| collection_binding_id, | |||
| 'annotation' | |||
| ) | |||
| dataset = Dataset( | |||
| id=app_id, | |||
| tenant_id=tenant_id, | |||
| indexing_technique='high_quality', | |||
| embedding_model_provider=dataset_collection_binding.provider_name, | |||
| embedding_model=dataset_collection_binding.model_name, | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| document = Document( | |||
| page_content=question, | |||
| metadata={ | |||
| "annotation_id": annotation_id, | |||
| "app_id": app_id, | |||
| "doc_id": annotation_id | |||
| } | |||
| ) | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.delete_by_metadata_field('annotation_id', annotation_id) | |||
| index.add_texts([document]) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style( | |||
| 'Build index successful for annotation: {} latency: {}'.format(annotation_id, end_at - start_at), | |||
| fg='green')) | |||
| except Exception: | |||
| logging.exception("Build index for annotation failed") | |||