Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>tags/0.3.24
| @@ -81,6 +81,7 @@ class BaseApiKeyListResource(Resource): | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| api_token = ApiToken() | |||
| setattr(api_token, self.resource_id_field, resource_id) | |||
| api_token.tenant_id = current_user.current_tenant_id | |||
| api_token.token = key | |||
| api_token.type = self.resource_type | |||
| db.session.add(api_token) | |||
| @@ -19,41 +19,13 @@ from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from events.app_event import app_was_created, app_was_deleted | |||
| from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \ | |||
| app_detail_fields_with_site | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| from models.model import App, AppModelConfig, Site | |||
| from services.app_model_config_service import AppModelConfigService | |||
| model_config_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), | |||
| 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), | |||
| 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), | |||
| 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), | |||
| 'more_like_this': fields.Raw(attribute='more_like_this_dict'), | |||
| 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'user_input_form': fields.Raw(attribute='user_input_form_list'), | |||
| 'dataset_query_variable': fields.String, | |||
| 'pre_prompt': fields.String, | |||
| 'agent_mode': fields.Raw(attribute='agent_mode_dict'), | |||
| } | |||
| app_detail_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'enable_site': fields.Boolean, | |||
| 'enable_api': fields.Boolean, | |||
| 'api_rpm': fields.Integer, | |||
| 'api_rph': fields.Integer, | |||
| 'is_demo': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), | |||
| 'created_at': TimestampField | |||
| } | |||
| def _get_app(app_id, tenant_id): | |||
| app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() | |||
| @@ -63,35 +35,6 @@ def _get_app(app_id, tenant_id): | |||
| class AppListApi(Resource): | |||
| prompt_config_fields = { | |||
| 'prompt_template': fields.String, | |||
| } | |||
| model_config_partial_fields = { | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'pre_prompt': fields.String, | |||
| } | |||
| app_partial_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'enable_site': fields.Boolean, | |||
| 'enable_api': fields.Boolean, | |||
| 'is_demo': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), | |||
| 'created_at': TimestampField | |||
| } | |||
| app_pagination_fields = { | |||
| 'page': fields.Integer, | |||
| 'limit': fields.Integer(attribute='per_page'), | |||
| 'total': fields.Integer, | |||
| 'has_more': fields.Boolean(attribute='has_next'), | |||
| 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -238,18 +181,6 @@ class AppListApi(Resource): | |||
| class AppTemplateApi(Resource): | |||
| template_fields = { | |||
| 'name': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'description': fields.String, | |||
| 'mode': fields.String, | |||
| 'model_config': fields.Nested(model_config_fields), | |||
| } | |||
| template_list_fields = { | |||
| 'data': fields.List(fields.Nested(template_fields)), | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -268,38 +199,6 @@ class AppTemplateApi(Resource): | |||
| class AppApi(Resource): | |||
| site_fields = { | |||
| 'access_token': fields.String(attribute='code'), | |||
| 'code': fields.String, | |||
| 'title': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'description': fields.String, | |||
| 'default_language': fields.String, | |||
| 'customize_domain': fields.String, | |||
| 'copyright': fields.String, | |||
| 'privacy_policy': fields.String, | |||
| 'customize_token_strategy': fields.String, | |||
| 'prompt_public': fields.Boolean, | |||
| 'app_base_url': fields.String, | |||
| } | |||
| app_detail_fields_with_site = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'enable_site': fields.Boolean, | |||
| 'enable_api': fields.Boolean, | |||
| 'api_rpm': fields.Integer, | |||
| 'api_rph': fields.Integer, | |||
| 'is_demo': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), | |||
| 'site': fields.Nested(site_fields), | |||
| 'api_base_url': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -13,107 +13,14 @@ from controllers.console import api | |||
| from controllers.console.app import _get_app | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \ | |||
| conversation_message_detail_fields, conversation_with_summary_pagination_fields | |||
| from libs.helper import TimestampField, datetime_string, uuid_value | |||
| from extensions.ext_database import db | |||
| from models.model import Message, MessageAnnotation, Conversation | |||
| account_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'email': fields.String | |||
| } | |||
| feedback_fields = { | |||
| 'rating': fields.String, | |||
| 'content': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account': fields.Nested(account_fields, allow_null=True), | |||
| } | |||
| annotation_fields = { | |||
| 'content': fields.String, | |||
| 'account': fields.Nested(account_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| message_detail_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'message': fields.Raw, | |||
| 'message_tokens': fields.Integer, | |||
| 'answer': fields.String, | |||
| 'answer_tokens': fields.Integer, | |||
| 'provider_response_latency': fields.Float, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'feedbacks': fields.List(fields.Nested(feedback_fields)), | |||
| 'annotation': fields.Nested(annotation_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| feedback_stat_fields = { | |||
| 'like': fields.Integer, | |||
| 'dislike': fields.Integer | |||
| } | |||
| model_config_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw, | |||
| 'model': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'pre_prompt': fields.String, | |||
| 'agent_mode': fields.Raw, | |||
| } | |||
| class CompletionConversationApi(Resource): | |||
| class MessageTextField(fields.Raw): | |||
| def format(self, value): | |||
| return value[0]['text'] if value else '' | |||
| simple_configs_fields = { | |||
| 'prompt_template': fields.String, | |||
| } | |||
| simple_model_config_fields = { | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'pre_prompt': fields.String, | |||
| } | |||
| simple_message_detail_fields = { | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'message': MessageTextField, | |||
| 'answer': fields.String, | |||
| } | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_end_user_session_id': fields.String(), | |||
| 'from_account_id': fields.String, | |||
| 'read_at': TimestampField, | |||
| 'created_at': TimestampField, | |||
| 'annotation': fields.Nested(annotation_fields, allow_null=True), | |||
| 'model_config': fields.Nested(simple_model_config_fields), | |||
| 'user_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'admin_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') | |||
| } | |||
| conversation_pagination_fields = { | |||
| 'page': fields.Integer, | |||
| 'limit': fields.Integer(attribute='per_page'), | |||
| 'total': fields.Integer, | |||
| 'has_more': fields.Boolean(attribute='has_next'), | |||
| 'data': fields.List(fields.Nested(conversation_fields), attribute='items') | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -191,21 +98,11 @@ class CompletionConversationApi(Resource): | |||
| class CompletionConversationDetailApi(Resource): | |||
| conversation_detail_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'model_config': fields.Nested(model_config_fields), | |||
| 'message': fields.Nested(message_detail_fields, attribute='first_message'), | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(conversation_detail_fields) | |||
| @marshal_with(conversation_message_detail_fields) | |||
| def get(self, app_id, conversation_id): | |||
| app_id = str(app_id) | |||
| conversation_id = str(conversation_id) | |||
| @@ -234,44 +131,11 @@ class CompletionConversationDetailApi(Resource): | |||
| class ChatConversationApi(Resource): | |||
| simple_configs_fields = { | |||
| 'prompt_template': fields.String, | |||
| } | |||
| simple_model_config_fields = { | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'pre_prompt': fields.String, | |||
| } | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_end_user_session_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'summary': fields.String(attribute='summary_or_query'), | |||
| 'read_at': TimestampField, | |||
| 'created_at': TimestampField, | |||
| 'annotated': fields.Boolean, | |||
| 'model_config': fields.Nested(simple_model_config_fields), | |||
| 'message_count': fields.Integer, | |||
| 'user_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'admin_feedback_stats': fields.Nested(feedback_stat_fields) | |||
| } | |||
| conversation_pagination_fields = { | |||
| 'page': fields.Integer, | |||
| 'limit': fields.Integer(attribute='per_page'), | |||
| 'total': fields.Integer, | |||
| 'has_more': fields.Boolean(attribute='has_next'), | |||
| 'data': fields.List(fields.Nested(conversation_fields), attribute='items') | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(conversation_pagination_fields) | |||
| @marshal_with(conversation_with_summary_pagination_fields) | |||
| def get(self, app_id): | |||
| app_id = str(app_id) | |||
| @@ -356,19 +220,6 @@ class ChatConversationApi(Resource): | |||
| class ChatConversationDetailApi(Resource): | |||
| conversation_detail_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'annotated': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_fields), | |||
| 'message_count': fields.Integer, | |||
| 'user_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'admin_feedback_stats': fields.Nested(feedback_stat_fields) | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -17,6 +17,7 @@ from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.login.login import login_required | |||
| from fields.conversation_fields import message_detail_fields | |||
| from libs.helper import uuid_value, TimestampField | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| from extensions.ext_database import db | |||
| @@ -27,44 +28,6 @@ from services.errors.conversation import ConversationNotExistsError | |||
| from services.errors.message import MessageNotExistsError | |||
| from services.message_service import MessageService | |||
| account_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'email': fields.String | |||
| } | |||
| feedback_fields = { | |||
| 'rating': fields.String, | |||
| 'content': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account': fields.Nested(account_fields, allow_null=True), | |||
| } | |||
| annotation_fields = { | |||
| 'content': fields.String, | |||
| 'account': fields.Nested(account_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| message_detail_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'message': fields.Raw, | |||
| 'message_tokens': fields.Integer, | |||
| 'answer': fields.String, | |||
| 'answer_tokens': fields.Integer, | |||
| 'provider_response_latency': fields.Float, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'feedbacks': fields.List(fields.Nested(feedback_fields)), | |||
| 'annotation': fields.Nested(annotation_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| class ChatMessageListApi(Resource): | |||
| message_infinite_scroll_pagination_fields = { | |||
| @@ -8,26 +8,11 @@ from controllers.console import api | |||
| from controllers.console.app import _get_app | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from fields.app_fields import app_site_fields | |||
| from libs.helper import supported_language | |||
| from extensions.ext_database import db | |||
| from models.model import Site | |||
| app_site_fields = { | |||
| 'app_id': fields.String, | |||
| 'access_token': fields.String(attribute='code'), | |||
| 'code': fields.String, | |||
| 'title': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'description': fields.String, | |||
| 'default_language': fields.String, | |||
| 'customize_domain': fields.String, | |||
| 'copyright': fields.String, | |||
| 'privacy_policy': fields.String, | |||
| 'customize_token_strategy': fields.String, | |||
| 'prompt_public': fields.Boolean | |||
| } | |||
| def parse_app_site_args(): | |||
| parser = reqparse.RequestParser() | |||
| @@ -14,6 +14,7 @@ from controllers.console.wraps import account_initialization_required | |||
| from core.data_loader.loader.notion import NotionLoader | |||
| from core.indexing_runner import IndexingRunner | |||
| from extensions.ext_database import db | |||
| from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields | |||
| from libs.helper import TimestampField | |||
| from models.dataset import Document | |||
| from models.source import DataSourceBinding | |||
| @@ -24,37 +25,6 @@ cache = TTLCache(maxsize=None, ttl=30) | |||
| class DataSourceApi(Resource): | |||
| integrate_icon_fields = { | |||
| 'type': fields.String, | |||
| 'url': fields.String, | |||
| 'emoji': fields.String | |||
| } | |||
| integrate_page_fields = { | |||
| 'page_name': fields.String, | |||
| 'page_id': fields.String, | |||
| 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), | |||
| 'parent_id': fields.String, | |||
| 'type': fields.String | |||
| } | |||
| integrate_workspace_fields = { | |||
| 'workspace_name': fields.String, | |||
| 'workspace_id': fields.String, | |||
| 'workspace_icon': fields.String, | |||
| 'pages': fields.List(fields.Nested(integrate_page_fields)), | |||
| 'total': fields.Integer | |||
| } | |||
| integrate_fields = { | |||
| 'id': fields.String, | |||
| 'provider': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'is_bound': fields.Boolean, | |||
| 'disabled': fields.Boolean, | |||
| 'link': fields.String, | |||
| 'source_info': fields.Nested(integrate_workspace_fields) | |||
| } | |||
| integrate_list_fields = { | |||
| 'data': fields.List(fields.Nested(integrate_fields)), | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -131,28 +101,6 @@ class DataSourceApi(Resource): | |||
| class DataSourceNotionListApi(Resource): | |||
| integrate_icon_fields = { | |||
| 'type': fields.String, | |||
| 'url': fields.String, | |||
| 'emoji': fields.String | |||
| } | |||
| integrate_page_fields = { | |||
| 'page_name': fields.String, | |||
| 'page_id': fields.String, | |||
| 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), | |||
| 'is_bound': fields.Boolean, | |||
| 'parent_id': fields.String, | |||
| 'type': fields.String | |||
| } | |||
| integrate_workspace_fields = { | |||
| 'workspace_name': fields.String, | |||
| 'workspace_id': fields.String, | |||
| 'workspace_icon': fields.String, | |||
| 'pages': fields.List(fields.Nested(integrate_page_fields)) | |||
| } | |||
| integrate_notion_info_list_fields = { | |||
| 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -1,6 +1,9 @@ | |||
| # -*- coding:utf-8 -*- | |||
| from flask import request | |||
| import flask_restful | |||
| from flask import request, current_app | |||
| from flask_login import current_user | |||
| from controllers.console.apikey import api_key_list, api_key_fields | |||
| from core.login.login import login_required | |||
| from flask_restful import Resource, reqparse, fields, marshal, marshal_with | |||
| from werkzeug.exceptions import NotFound, Forbidden | |||
| @@ -12,45 +15,16 @@ from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from libs.helper import TimestampField | |||
| from fields.app_fields import related_app_list | |||
| from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields | |||
| from fields.document_fields import document_status_fields | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Document | |||
| from models.model import UploadFile | |||
| from models.model import UploadFile, ApiToken | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.provider_service import ProviderService | |||
| dataset_detail_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'provider': fields.String, | |||
| 'permission': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'indexing_technique': fields.String, | |||
| 'app_count': fields.Integer, | |||
| 'document_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'updated_by': fields.String, | |||
| 'updated_at': TimestampField, | |||
| 'embedding_model': fields.String, | |||
| 'embedding_model_provider': fields.String, | |||
| 'embedding_available': fields.Boolean | |||
| } | |||
| dataset_query_detail_fields = { | |||
| "id": fields.String, | |||
| "content": fields.String, | |||
| "source": fields.String, | |||
| "source_app_id": fields.String, | |||
| "created_by_role": fields.String, | |||
| "created_by": fields.String, | |||
| "created_at": TimestampField | |||
| } | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| @@ -82,7 +56,8 @@ class DatasetListApi(Resource): | |||
| # check embedding setting | |||
| provider_service = ProviderService() | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, | |||
| ModelType.EMBEDDINGS.value) | |||
| # if len(valid_model_list) == 0: | |||
| # raise ProviderNotInitializeError( | |||
| # f"No Embedding Model available. Please configure a valid provider " | |||
| @@ -157,7 +132,8 @@ class DatasetApi(Resource): | |||
| # check embedding setting | |||
| provider_service = ProviderService() | |||
| # get valid model list | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, ModelType.EMBEDDINGS.value) | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, | |||
| ModelType.EMBEDDINGS.value) | |||
| model_names = [] | |||
| for valid_model in valid_model_list: | |||
| model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") | |||
| @@ -271,7 +247,8 @@ class DatasetIndexingEstimateApi(Resource): | |||
| parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| # validate args | |||
| DocumentService.estimate_args_validate(args) | |||
| @@ -320,18 +297,6 @@ class DatasetIndexingEstimateApi(Resource): | |||
| class DatasetRelatedAppListApi(Resource): | |||
| app_detail_kernel_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| } | |||
| related_app_list = { | |||
| 'data': fields.List(fields.Nested(app_detail_kernel_fields)), | |||
| 'total': fields.Integer, | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -363,24 +328,6 @@ class DatasetRelatedAppListApi(Resource): | |||
| class DatasetIndexingStatusApi(Resource): | |||
| document_status_fields = { | |||
| 'id': fields.String, | |||
| 'indexing_status': fields.String, | |||
| 'processing_started_at': TimestampField, | |||
| 'parsing_completed_at': TimestampField, | |||
| 'cleaning_completed_at': TimestampField, | |||
| 'splitting_completed_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'paused_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField, | |||
| 'completed_segments': fields.Integer, | |||
| 'total_segments': fields.Integer, | |||
| } | |||
| document_status_fields_list = { | |||
| 'data': fields.List(fields.Nested(document_status_fields)) | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -400,16 +347,97 @@ class DatasetIndexingStatusApi(Resource): | |||
| DocumentSegment.status != 're_segment').count() | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| documents_status.append(marshal(document, self.document_status_fields)) | |||
| documents_status.append(marshal(document, document_status_fields)) | |||
| data = { | |||
| 'data': documents_status | |||
| } | |||
| return data | |||
| class DatasetApiKeyApi(Resource): | |||
| max_keys = 10 | |||
| token_prefix = 'dataset-' | |||
| resource_type = 'dataset' | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(api_key_list) | |||
| def get(self): | |||
| keys = db.session.query(ApiToken). \ | |||
| filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ | |||
| all() | |||
| return {"items": keys} | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(api_key_fields) | |||
| def post(self): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| current_key_count = db.session.query(ApiToken). \ | |||
| filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ | |||
| count() | |||
| if current_key_count >= self.max_keys: | |||
| flask_restful.abort( | |||
| 400, | |||
| message=f"Cannot create more than {self.max_keys} API keys for this resource type.", | |||
| code='max_keys_exceeded' | |||
| ) | |||
| key = ApiToken.generate_api_key(self.token_prefix, 24) | |||
| api_token = ApiToken() | |||
| api_token.tenant_id = current_user.current_tenant_id | |||
| api_token.token = key | |||
| api_token.type = self.resource_type | |||
| db.session.add(api_token) | |||
| db.session.commit() | |||
| return api_token, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, api_key_id): | |||
| api_key_id = str(api_key_id) | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| key = db.session.query(ApiToken). \ | |||
| filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, | |||
| ApiToken.id == api_key_id). \ | |||
| first() | |||
| if key is None: | |||
| flask_restful.abort(404, message='API key not found') | |||
| db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() | |||
| db.session.commit() | |||
| return {'result': 'success'}, 204 | |||
| class DatasetApiBaseUrlApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| return { | |||
| 'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL'] | |||
| else request.host_url.rstrip('/')) + '/v1' | |||
| } | |||
| api.add_resource(DatasetListApi, '/datasets') | |||
| api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>') | |||
| api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries') | |||
| api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') | |||
| api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps') | |||
| api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status') | |||
| api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') | |||
| api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') | |||
| @@ -23,6 +23,8 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE | |||
| LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_redis import redis_client | |||
| from fields.document_fields import document_with_segments_fields, document_fields, \ | |||
| dataset_and_document_fields, document_status_fields | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| from models.dataset import DatasetProcessRule, Dataset | |||
| @@ -32,64 +34,6 @@ from services.dataset_service import DocumentService, DatasetService | |||
| from tasks.add_document_to_index_task import add_document_to_index_task | |||
| from tasks.remove_document_from_index_task import remove_document_from_index_task | |||
| dataset_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'permission': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'indexing_technique': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| } | |||
| document_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'data_source_type': fields.String, | |||
| 'data_source_info': fields.Raw(attribute='data_source_info_dict'), | |||
| 'dataset_process_rule_id': fields.String, | |||
| 'name': fields.String, | |||
| 'created_from': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'tokens': fields.Integer, | |||
| 'indexing_status': fields.String, | |||
| 'error': fields.String, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'archived': fields.Boolean, | |||
| 'display_status': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'hit_count': fields.Integer, | |||
| 'doc_form': fields.String, | |||
| } | |||
| document_with_segments_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'data_source_type': fields.String, | |||
| 'data_source_info': fields.Raw(attribute='data_source_info_dict'), | |||
| 'dataset_process_rule_id': fields.String, | |||
| 'name': fields.String, | |||
| 'created_from': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'tokens': fields.Integer, | |||
| 'indexing_status': fields.String, | |||
| 'error': fields.String, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'archived': fields.Boolean, | |||
| 'display_status': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'hit_count': fields.Integer, | |||
| 'completed_segments': fields.Integer, | |||
| 'total_segments': fields.Integer | |||
| } | |||
| class DocumentResource(Resource): | |||
| def get_document(self, dataset_id: str, document_id: str) -> Document: | |||
| @@ -303,11 +247,6 @@ class DatasetDocumentListApi(Resource): | |||
| class DatasetInitApi(Resource): | |||
| dataset_and_document_fields = { | |||
| 'dataset': fields.Nested(dataset_fields), | |||
| 'documents': fields.List(fields.Nested(document_fields)), | |||
| 'batch': fields.String | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -504,24 +443,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| document_status_fields = { | |||
| 'id': fields.String, | |||
| 'indexing_status': fields.String, | |||
| 'processing_started_at': TimestampField, | |||
| 'parsing_completed_at': TimestampField, | |||
| 'cleaning_completed_at': TimestampField, | |||
| 'splitting_completed_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'paused_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField, | |||
| 'completed_segments': fields.Integer, | |||
| 'total_segments': fields.Integer, | |||
| } | |||
| document_status_fields_list = { | |||
| 'data': fields.List(fields.Nested(document_status_fields)) | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -541,7 +462,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| document.total_segments = total_segments | |||
| if document.is_paused: | |||
| document.indexing_status = 'paused' | |||
| documents_status.append(marshal(document, self.document_status_fields)) | |||
| documents_status.append(marshal(document, document_status_fields)) | |||
| data = { | |||
| 'data': documents_status | |||
| } | |||
| @@ -549,20 +470,6 @@ class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| class DocumentIndexingStatusApi(DocumentResource): | |||
| document_status_fields = { | |||
| 'id': fields.String, | |||
| 'indexing_status': fields.String, | |||
| 'processing_started_at': TimestampField, | |||
| 'parsing_completed_at': TimestampField, | |||
| 'cleaning_completed_at': TimestampField, | |||
| 'splitting_completed_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'paused_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField, | |||
| 'completed_segments': fields.Integer, | |||
| 'total_segments': fields.Integer, | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -586,7 +493,7 @@ class DocumentIndexingStatusApi(DocumentResource): | |||
| document.total_segments = total_segments | |||
| if document.is_paused: | |||
| document.indexing_status = 'paused' | |||
| return marshal(document, self.document_status_fields) | |||
| return marshal(document, document_status_fields) | |||
| class DocumentDetailApi(DocumentResource): | |||
| @@ -3,7 +3,7 @@ import uuid | |||
| from datetime import datetime | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, reqparse, fields, marshal | |||
| from flask_restful import Resource, reqparse, marshal | |||
| from werkzeug.exceptions import NotFound, Forbidden | |||
| import services | |||
| @@ -17,6 +17,7 @@ from core.model_providers.model_factory import ModelFactory | |||
| from core.login.login import login_required | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from fields.segment_fields import segment_fields | |||
| from models.dataset import DocumentSegment | |||
| from libs.helper import TimestampField | |||
| @@ -26,36 +27,6 @@ from tasks.disable_segment_from_index_task import disable_segment_from_index_tas | |||
| from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task | |||
| import pandas as pd | |||
| segment_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'document_id': fields.String, | |||
| 'content': fields.String, | |||
| 'answer': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'tokens': fields.Integer, | |||
| 'keywords': fields.List(fields.String), | |||
| 'index_node_id': fields.String, | |||
| 'index_node_hash': fields.String, | |||
| 'hit_count': fields.Integer, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'status': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'indexing_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField | |||
| } | |||
| segment_list_response = { | |||
| 'data': fields.List(fields.Nested(segment_fields)), | |||
| 'has_more': fields.Boolean, | |||
| 'limit': fields.Integer | |||
| } | |||
| class DatasetDocumentSegmentListApi(Resource): | |||
| @setup_required | |||
| @@ -1,28 +1,19 @@ | |||
| import datetime | |||
| import hashlib | |||
| import tempfile | |||
| import chardet | |||
| import time | |||
| import uuid | |||
| from pathlib import Path | |||
| from cachetools import TTLCache | |||
| from flask import request, current_app | |||
| from flask_login import current_user | |||
| import services | |||
| from core.login.login import login_required | |||
| from flask_restful import Resource, marshal_with, fields | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ | |||
| UnsupportedFileTypeError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from extensions.ext_storage import storage | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| from models.model import UploadFile | |||
| from fields.file_fields import upload_config_fields, file_fields | |||
| from services.file_service import FileService | |||
| cache = TTLCache(maxsize=None, ttl=30) | |||
| @@ -31,10 +22,6 @@ PREVIEW_WORDS_LIMIT = 3000 | |||
| class FileApi(Resource): | |||
| upload_config_fields = { | |||
| 'file_size_limit': fields.Integer, | |||
| 'batch_count_limit': fields.Integer | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @@ -48,16 +35,6 @@ class FileApi(Resource): | |||
| 'batch_count_limit': batch_count_limit | |||
| }, 200 | |||
| file_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'size': fields.Integer, | |||
| 'extension': fields.String, | |||
| 'mime_type': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| } | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @@ -73,45 +50,13 @@ class FileApi(Resource): | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| file_content = file.read() | |||
| file_size = len(file_content) | |||
| file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 | |||
| if file_size > file_size_limit: | |||
| message = "({file_size} > {file_size_limit})" | |||
| raise FileTooLargeError(message) | |||
| extension = file.filename.split('.')[-1] | |||
| if extension.lower() not in ALLOWED_EXTENSIONS: | |||
| try: | |||
| upload_file = FileService.upload_file(file) | |||
| except services.errors.file.FileTooLargeError as file_too_large_error: | |||
| raise FileTooLargeError(file_too_large_error.description) | |||
| except services.errors.file.UnsupportedFileTypeError: | |||
| raise UnsupportedFileTypeError() | |||
| # user uuid as file name | |||
| file_uuid = str(uuid.uuid4()) | |||
| file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension | |||
| # save file to storage | |||
| storage.save(file_key, file_content) | |||
| # save file to db | |||
| config = current_app.config | |||
| upload_file = UploadFile( | |||
| tenant_id=current_user.current_tenant_id, | |||
| storage_type=config['STORAGE_TYPE'], | |||
| key=file_key, | |||
| name=file.filename, | |||
| size=file_size, | |||
| extension=extension, | |||
| mime_type=file.mimetype, | |||
| created_by=current_user.id, | |||
| created_at=datetime.datetime.utcnow(), | |||
| used=False, | |||
| hash=hashlib.sha3_256(file_content).hexdigest() | |||
| ) | |||
| db.session.add(upload_file) | |||
| db.session.commit() | |||
| return upload_file, 201 | |||
| @@ -121,26 +66,7 @@ class FilePreviewApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, file_id): | |||
| file_id = str(file_id) | |||
| key = file_id + request.path | |||
| cached_response = cache.get(key) | |||
| if cached_response and time.time() - cached_response['timestamp'] < cache.ttl: | |||
| return cached_response['response'] | |||
| upload_file = db.session.query(UploadFile) \ | |||
| .filter(UploadFile.id == file_id) \ | |||
| .first() | |||
| if not upload_file: | |||
| raise NotFound("File not found") | |||
| # extract text from file | |||
| extension = upload_file.extension | |||
| if extension.lower() not in ALLOWED_EXTENSIONS: | |||
| raise UnsupportedFileTypeError() | |||
| text = FileExtractor.load(upload_file, return_text=True) | |||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | |||
| text = FileService.get_file_preview(file_id) | |||
| return {'content': text} | |||
| @@ -2,7 +2,7 @@ import logging | |||
| from flask_login import current_user | |||
| from core.login.login import login_required | |||
| from flask_restful import Resource, reqparse, marshal, fields | |||
| from flask_restful import Resource, reqparse, marshal | |||
| from werkzeug.exceptions import InternalServerError, NotFound, Forbidden | |||
| import services | |||
| @@ -14,48 +14,10 @@ from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| LLMBadRequestError | |||
| from libs.helper import TimestampField | |||
| from fields.hit_testing_fields import hit_testing_record_fields | |||
| from services.dataset_service import DatasetService | |||
| from services.hit_testing_service import HitTestingService | |||
| document_fields = { | |||
| 'id': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'name': fields.String, | |||
| 'doc_type': fields.String, | |||
| } | |||
| segment_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'document_id': fields.String, | |||
| 'content': fields.String, | |||
| 'answer': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'tokens': fields.Integer, | |||
| 'keywords': fields.List(fields.String), | |||
| 'index_node_id': fields.String, | |||
| 'index_node_hash': fields.String, | |||
| 'hit_count': fields.Integer, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'status': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'indexing_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField, | |||
| 'document': fields.Nested(document_fields), | |||
| } | |||
| hit_testing_record_fields = { | |||
| 'segment': fields.Nested(segment_fields), | |||
| 'score': fields.Float, | |||
| 'tsne_position': fields.Raw | |||
| } | |||
| class HitTestingApi(Resource): | |||
| @@ -7,26 +7,12 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.explore.error import NotChatAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields | |||
| from libs.helper import TimestampField, uuid_value | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError | |||
| from services.web_conversation_service import WebConversationService | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'status': fields.String, | |||
| 'introduction': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| conversation_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(conversation_fields)) | |||
| } | |||
| class ConversationListApi(InstalledAppResource): | |||
| @@ -76,7 +62,7 @@ class ConversationApi(InstalledAppResource): | |||
| class ConversationRenameApi(InstalledAppResource): | |||
| @marshal_with(conversation_fields) | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, installed_app, c_id): | |||
| app_model = installed_app.app | |||
| if app_model.mode != 'chat': | |||
| @@ -11,32 +11,11 @@ from controllers.console import api | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from controllers.console.wraps import account_initialization_required | |||
| from extensions.ext_database import db | |||
| from fields.installed_app_fields import installed_app_list_fields | |||
| from libs.helper import TimestampField | |||
| from models.model import App, InstalledApp, RecommendedApp | |||
| from services.account_service import TenantService | |||
| app_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String | |||
| } | |||
| installed_app_fields = { | |||
| 'id': fields.String, | |||
| 'app': fields.Nested(app_fields), | |||
| 'app_owner_tenant_id': fields.String, | |||
| 'is_pinned': fields.Boolean, | |||
| 'last_used_at': TimestampField, | |||
| 'editable': fields.Boolean, | |||
| 'uninstallable': fields.Boolean, | |||
| } | |||
| installed_app_list_fields = { | |||
| 'installed_apps': fields.List(fields.Nested(installed_app_fields)) | |||
| } | |||
| class InstalledAppsListApi(Resource): | |||
| @login_required | |||
| @@ -17,6 +17,7 @@ from controllers.console.explore.error import NotCompletionAppError, AppSuggeste | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from fields.message_fields import message_infinite_scroll_pagination_fields | |||
| from libs.helper import uuid_value, TimestampField | |||
| from services.completion_service import CompletionService | |||
| from services.errors.app import MoreLikeThisDisabledError | |||
| @@ -26,45 +27,6 @@ from services.message_service import MessageService | |||
| class MessageListApi(InstalledAppResource): | |||
| feedback_fields = { | |||
| 'rating': fields.String | |||
| } | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField | |||
| } | |||
| message_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(message_fields)) | |||
| } | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| def get(self, installed_app): | |||
| @@ -6,31 +6,17 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.console import api | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from fields.conversation_fields import conversation_with_model_config_infinite_scroll_pagination_fields, \ | |||
| conversation_with_model_config_fields | |||
| from libs.helper import TimestampField, uuid_value | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError | |||
| from services.web_conversation_service import WebConversationService | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'status': fields.String, | |||
| 'introduction': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'model_config': fields.Raw, | |||
| } | |||
| conversation_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(conversation_fields)) | |||
| } | |||
| class UniversalChatConversationListApi(UniversalChatResource): | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| @marshal_with(conversation_with_model_config_infinite_scroll_pagination_fields) | |||
| def get(self, universal_app): | |||
| app_model = universal_app | |||
| @@ -73,7 +59,7 @@ class UniversalChatConversationApi(UniversalChatResource): | |||
| class UniversalChatConversationRenameApi(UniversalChatResource): | |||
| @marshal_with(conversation_fields) | |||
| @marshal_with(conversation_with_model_config_fields) | |||
| def post(self, universal_app, c_id): | |||
| app_model = universal_app | |||
| conversation_id = str(c_id) | |||
| @@ -9,4 +9,4 @@ api = ExternalApi(bp) | |||
| from .app import completion, app, conversation, message, audio | |||
| from .dataset import document | |||
| from .dataset import document, segment, dataset | |||
| @@ -8,25 +8,11 @@ from controllers.service_api import api | |||
| from controllers.service_api.app import create_or_update_end_user_for_user_id | |||
| from controllers.service_api.app.error import NotChatAppError | |||
| from controllers.service_api.wraps import AppApiResource | |||
| from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields | |||
| from libs.helper import TimestampField, uuid_value | |||
| import services | |||
| from services.conversation_service import ConversationService | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'status': fields.String, | |||
| 'introduction': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| conversation_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(conversation_fields)) | |||
| } | |||
| class ConversationApi(AppApiResource): | |||
| @@ -50,7 +36,7 @@ class ConversationApi(AppApiResource): | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| class ConversationDetailApi(AppApiResource): | |||
| @marshal_with(conversation_fields) | |||
| @marshal_with(simple_conversation_fields) | |||
| def delete(self, app_model, end_user, c_id): | |||
| if app_model.mode != 'chat': | |||
| raise NotChatAppError() | |||
| @@ -70,7 +56,7 @@ class ConversationDetailApi(AppApiResource): | |||
| class ConversationRenameApi(AppApiResource): | |||
| @marshal_with(conversation_fields) | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, app_model, end_user, c_id): | |||
| if app_model.mode != 'chat': | |||
| raise NotChatAppError() | |||
| @@ -0,0 +1,84 @@ | |||
| from flask import request | |||
| from flask_restful import reqparse, marshal | |||
| import services.dataset_service | |||
| from controllers.service_api import api | |||
| from controllers.service_api.dataset.error import DatasetNameDuplicateError | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
| from core.login.login import current_user | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from extensions.ext_database import db | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from models.account import Account, TenantAccountJoin | |||
| from models.dataset import Dataset | |||
| from services.dataset_service import DatasetService | |||
| from services.provider_service import ProviderService | |||
| def _validate_name(name): | |||
| if not name or len(name) < 1 or len(name) > 40: | |||
| raise ValueError('Name must be between 1 to 40 characters.') | |||
| return name | |||
| class DatasetApi(DatasetApiResource): | |||
| """Resource for get datasets.""" | |||
| def get(self, tenant_id): | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| provider = request.args.get('provider', default="vendor") | |||
| datasets, total = DatasetService.get_datasets(page, limit, provider, | |||
| tenant_id, current_user) | |||
| # check embedding setting | |||
| provider_service = ProviderService() | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, | |||
| ModelType.EMBEDDINGS.value) | |||
| model_names = [] | |||
| for valid_model in valid_model_list: | |||
| model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") | |||
| data = marshal(datasets, dataset_detail_fields) | |||
| for item in data: | |||
| if item['indexing_technique'] == 'high_quality': | |||
| item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | |||
| if item_model in model_names: | |||
| item['embedding_available'] = True | |||
| else: | |||
| item['embedding_available'] = False | |||
| else: | |||
| item['embedding_available'] = True | |||
| response = { | |||
| 'data': data, | |||
| 'has_more': len(datasets) == limit, | |||
| 'limit': limit, | |||
| 'total': total, | |||
| 'page': page | |||
| } | |||
| return response, 200 | |||
| """Resource for datasets.""" | |||
| def post(self, tenant_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', nullable=False, required=True, | |||
| help='type is required. Name must be between 1 to 40 characters.', | |||
| type=_validate_name) | |||
| parser.add_argument('indexing_technique', type=str, location='json', | |||
| choices=('high_quality', 'economy'), | |||
| help='Invalid indexing technique.') | |||
| args = parser.parse_args() | |||
| try: | |||
| dataset = DatasetService.create_empty_dataset( | |||
| tenant_id=tenant_id, | |||
| name=args['name'], | |||
| indexing_technique=args['indexing_technique'], | |||
| account=current_user | |||
| ) | |||
| except services.errors.dataset.DatasetNameDuplicateError: | |||
| raise DatasetNameDuplicateError() | |||
| return marshal(dataset, dataset_detail_fields), 200 | |||
| api.add_resource(DatasetApi, '/datasets') | |||
| @@ -1,114 +1,291 @@ | |||
| import datetime | |||
| import json | |||
| import uuid | |||
| from flask import current_app | |||
| from flask_restful import reqparse | |||
| from flask import current_app, request | |||
| from flask_restful import reqparse, marshal | |||
| from sqlalchemy import desc | |||
| from werkzeug.exceptions import NotFound | |||
| import services.dataset_service | |||
| from controllers.service_api import api | |||
| from controllers.service_api.app.error import ProviderNotInitializeError | |||
| from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ | |||
| DatasetNotInitedError | |||
| NoFileUploadedError, TooManyFilesError | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
| from core.login.login import current_user | |||
| from core.model_providers.error import ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from fields.document_fields import document_fields, document_status_fields | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.model import UploadFile | |||
| from services.dataset_service import DocumentService | |||
| from services.file_service import FileService | |||
| class DocumentListApi(DatasetApiResource): | |||
| class DocumentAddByTextApi(DatasetApiResource): | |||
| """Resource for documents.""" | |||
| def post(self, dataset): | |||
| """Create document.""" | |||
| def post(self, tenant_id, dataset_id): | |||
| """Create document by text.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('text', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('doc_type', type=str, location='json') | |||
| parser.add_argument('doc_metadata', type=dict, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument('original_document_id', type=str, required=False, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| if not dataset.indexing_technique and not args['indexing_technique']: | |||
| raise ValueError('indexing_technique is required.') | |||
| upload_file = FileService.upload_text(args.get('text'), args.get('name')) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'data_source_type': 'upload_file', | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| } | |||
| args['data_source'] = data_source | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=current_user, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| return documents_and_batch_fields, 200 | |||
| class DocumentUpdateByTextApi(DatasetApiResource): | |||
| """Resource for update documents.""" | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Update document by text.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('name', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('text', type=str, required=False, nullable=True, location='json') | |||
| parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, | |||
| location='json') | |||
| args = parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset.indexing_technique: | |||
| raise DatasetNotInitedError("Dataset indexing technique must be set.") | |||
| doc_type = args.get('doc_type') | |||
| doc_metadata = args.get('doc_metadata') | |||
| if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: | |||
| raise ValueError('Invalid doc_type.') | |||
| # user uuid as file name | |||
| file_uuid = str(uuid.uuid4()) | |||
| file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt' | |||
| # save file to storage | |||
| storage.save(file_key, args.get('text')) | |||
| # save file to db | |||
| config = current_app.config | |||
| upload_file = UploadFile( | |||
| tenant_id=dataset.tenant_id, | |||
| storage_type=config['STORAGE_TYPE'], | |||
| key=file_key, | |||
| name=args.get('name') + '.txt', | |||
| size=len(args.get('text')), | |||
| extension='txt', | |||
| mime_type='text/plain', | |||
| created_by=dataset.created_by, | |||
| created_at=datetime.datetime.utcnow(), | |||
| used=True, | |||
| used_by=dataset.created_by, | |||
| used_at=datetime.datetime.utcnow() | |||
| ) | |||
| db.session.add(upload_file) | |||
| db.session.commit() | |||
| document_data = { | |||
| 'data_source': { | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| if args['text']: | |||
| upload_file = FileService.upload_text(args.get('text'), args.get('name')) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info': [ | |||
| { | |||
| 'upload_file_id': upload_file.id | |||
| 'info_list': { | |||
| 'data_source_type': 'upload_file', | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| ] | |||
| } | |||
| } | |||
| args['data_source'] = data_source | |||
| # validate args | |||
| args['original_document_id'] = str(document_id) | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=current_user, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| return documents_and_batch_fields, 200 | |||
| class DocumentAddByFileApi(DatasetApiResource): | |||
| """Resource for documents.""" | |||
| def post(self, tenant_id, dataset_id): | |||
| """Create document by upload file.""" | |||
| args = {} | |||
| if 'data' in request.form: | |||
| args = json.loads(request.form['data']) | |||
| if 'doc_form' not in args: | |||
| args['doc_form'] = 'text_model' | |||
| if 'doc_language' not in args: | |||
| args['doc_language'] = 'English' | |||
| # get dataset info | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| if not dataset.indexing_technique and not args['indexing_technique']: | |||
| raise ValueError('indexing_technique is required.') | |||
| # save file info | |||
| file = request.files['file'] | |||
| # check file | |||
| if 'file' not in request.files: | |||
| raise NoFileUploadedError() | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| upload_file = FileService.upload_file(file) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| } | |||
| args['data_source'] = data_source | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=document_data, | |||
| document_data=args, | |||
| account=dataset.created_by_account, | |||
| dataset_process_rule=dataset.latest_process_rule, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| if doc_type and doc_metadata: | |||
| metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| return documents_and_batch_fields, 200 | |||
| document.doc_metadata = {} | |||
| for key, value_type in metadata_schema.items(): | |||
| value = doc_metadata.get(key) | |||
| if value is not None and isinstance(value, value_type): | |||
| document.doc_metadata[key] = value | |||
| class DocumentUpdateByFileApi(DatasetApiResource): | |||
| """Resource for update documents.""" | |||
| document.doc_type = doc_type | |||
| document.updated_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Update document by upload file.""" | |||
| args = {} | |||
| if 'data' in request.form: | |||
| args = json.loads(request.form['data']) | |||
| if 'doc_form' not in args: | |||
| args['doc_form'] = 'text_model' | |||
| if 'doc_language' not in args: | |||
| args['doc_language'] = 'English' | |||
| return {'id': document.id} | |||
| # get dataset info | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| if 'file' in request.files: | |||
| # save file info | |||
| file = request.files['file'] | |||
| if len(request.files) > 1: | |||
| raise TooManyFilesError() | |||
| upload_file = FileService.upload_file(file) | |||
| data_source = { | |||
| 'type': 'upload_file', | |||
| 'info_list': { | |||
| 'file_info_list': { | |||
| 'file_ids': [upload_file.id] | |||
| } | |||
| } | |||
| } | |||
| args['data_source'] = data_source | |||
| # validate args | |||
| args['original_document_id'] = str(document_id) | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id( | |||
| dataset=dataset, | |||
| document_data=args, | |||
| account=dataset.created_by_account, | |||
| dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, | |||
| created_from='api' | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| document = documents[0] | |||
| documents_and_batch_fields = { | |||
| 'document': marshal(document, document_fields), | |||
| 'batch': batch | |||
| } | |||
| return documents_and_batch_fields, 200 | |||
| class DocumentApi(DatasetApiResource): | |||
| def delete(self, dataset, document_id): | |||
| class DocumentDeleteApi(DatasetApiResource): | |||
| def delete(self, tenant_id, dataset_id, document_id): | |||
| """Delete document.""" | |||
| document_id = str(document_id) | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| # get dataset info | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise ValueError('Dataset is not exist.') | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| @@ -126,8 +303,85 @@ class DocumentApi(DatasetApiResource): | |||
| except services.errors.document.DocumentIndexingError: | |||
| raise DocumentIndexingError('Cannot delete document during indexing.') | |||
| return {'result': 'success'}, 204 | |||
| return {'result': 'success'}, 200 | |||
| class DocumentListApi(DatasetApiResource): | |||
| def get(self, tenant_id, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| page = request.args.get('page', default=1, type=int) | |||
| limit = request.args.get('limit', default=20, type=int) | |||
| search = request.args.get('keyword', default=None, type=str) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| query = Document.query.filter_by( | |||
| dataset_id=str(dataset_id), tenant_id=tenant_id) | |||
| if search: | |||
| search = f'%{search}%' | |||
| query = query.filter(Document.name.like(search)) | |||
| query = query.order_by(desc(Document.created_at)) | |||
| paginated_documents = query.paginate( | |||
| page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| documents = paginated_documents.items | |||
| response = { | |||
| 'data': marshal(documents, document_fields), | |||
| 'has_more': len(documents) == limit, | |||
| 'limit': limit, | |||
| 'total': paginated_documents.total, | |||
| 'page': page | |||
| } | |||
| return response | |||
| class DocumentIndexingStatusApi(DatasetApiResource): | |||
| def get(self, tenant_id, dataset_id, batch): | |||
| dataset_id = str(dataset_id) | |||
| batch = str(batch) | |||
| tenant_id = str(tenant_id) | |||
| # get dataset | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| # get documents | |||
| documents = DocumentService.get_batch_documents(dataset_id, batch) | |||
| if not documents: | |||
| raise NotFound('Documents not found.') | |||
| documents_status = [] | |||
| for document in documents: | |||
| completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != 're_segment').count() | |||
| total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), | |||
| DocumentSegment.status != 're_segment').count() | |||
| document.completed_segments = completed_segments | |||
| document.total_segments = total_segments | |||
| if document.is_paused: | |||
| document.indexing_status = 'paused' | |||
| documents_status.append(marshal(document, document_status_fields)) | |||
| data = { | |||
| 'data': documents_status | |||
| } | |||
| return data | |||
| api.add_resource(DocumentListApi, '/documents') | |||
| api.add_resource(DocumentApi, '/documents/<uuid:document_id>') | |||
| api.add_resource(DocumentAddByTextApi, '/datasets/<uuid:dataset_id>/document/create_by_text') | |||
| api.add_resource(DocumentAddByFileApi, '/datasets/<uuid:dataset_id>/document/create_by_file') | |||
| api.add_resource(DocumentUpdateByTextApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text') | |||
| api.add_resource(DocumentUpdateByFileApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file') | |||
| api.add_resource(DocumentDeleteApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>') | |||
| api.add_resource(DocumentListApi, '/datasets/<uuid:dataset_id>/documents') | |||
| api.add_resource(DocumentIndexingStatusApi, '/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status') | |||
| @@ -1,20 +1,73 @@ | |||
| # -*- coding:utf-8 -*- | |||
| from libs.exception import BaseHTTPException | |||
| class NoFileUploadedError(BaseHTTPException): | |||
| error_code = 'no_file_uploaded' | |||
| description = "Please upload your file." | |||
| code = 400 | |||
| class TooManyFilesError(BaseHTTPException): | |||
| error_code = 'too_many_files' | |||
| description = "Only one file is allowed." | |||
| code = 400 | |||
| class FileTooLargeError(BaseHTTPException): | |||
| error_code = 'file_too_large' | |||
| description = "File size exceeded. {message}" | |||
| code = 413 | |||
| class UnsupportedFileTypeError(BaseHTTPException): | |||
| error_code = 'unsupported_file_type' | |||
| description = "File type not allowed." | |||
| code = 415 | |||
| class HighQualityDatasetOnlyError(BaseHTTPException): | |||
| error_code = 'high_quality_dataset_only' | |||
| description = "Current operation only supports 'high-quality' datasets." | |||
| code = 400 | |||
| class DatasetNotInitializedError(BaseHTTPException): | |||
| error_code = 'dataset_not_initialized' | |||
| description = "The dataset is still being initialized or indexing. Please wait a moment." | |||
| code = 400 | |||
| class ArchivedDocumentImmutableError(BaseHTTPException): | |||
| error_code = 'archived_document_immutable' | |||
| description = "Cannot operate when document was archived." | |||
| description = "The archived document is not editable." | |||
| code = 403 | |||
| class DatasetNameDuplicateError(BaseHTTPException): | |||
| error_code = 'dataset_name_duplicate' | |||
| description = "The dataset name already exists. Please modify your dataset name." | |||
| code = 409 | |||
| class InvalidActionError(BaseHTTPException): | |||
| error_code = 'invalid_action' | |||
| description = "Invalid action." | |||
| code = 400 | |||
| class DocumentAlreadyFinishedError(BaseHTTPException): | |||
| error_code = 'document_already_finished' | |||
| description = "The document has been processed. Please refresh the page or go to the document details." | |||
| code = 400 | |||
| class DocumentIndexingError(BaseHTTPException): | |||
| error_code = 'document_indexing' | |||
| description = "Cannot operate document during indexing." | |||
| code = 403 | |||
| description = "The document is being processed and cannot be edited." | |||
| code = 400 | |||
| class DatasetNotInitedError(BaseHTTPException): | |||
| error_code = 'dataset_not_inited' | |||
| description = "The dataset is still being initialized or indexing. Please wait a moment." | |||
| code = 403 | |||
| class InvalidMetadataError(BaseHTTPException): | |||
| error_code = 'invalid_metadata' | |||
| description = "The metadata content is incorrect. Please check and verify." | |||
| code = 400 | |||
| @@ -0,0 +1,59 @@ | |||
| from flask_login import current_user | |||
| from flask_restful import reqparse, marshal | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.service_api import api | |||
| from controllers.service_api.app.error import ProviderNotInitializeError | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
| from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from fields.segment_fields import segment_fields | |||
| from models.dataset import Dataset | |||
| from services.dataset_service import DocumentService, SegmentService | |||
| class SegmentApi(DatasetApiResource): | |||
| """Resource for segments.""" | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Create single segment.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('segments', type=list, required=False, nullable=True, location='json') | |||
| args = parser.parse_args() | |||
| for args_item in args['segments']: | |||
| SegmentService.segment_create_args_validate(args_item, document) | |||
| segments = SegmentService.multi_create_segment(args['segments'], document, dataset) | |||
| return { | |||
| 'data': marshal(segments, segment_fields), | |||
| 'doc_form': document.doc_form | |||
| }, 200 | |||
| api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments') | |||
| @@ -2,11 +2,14 @@ | |||
| from datetime import datetime | |||
| from functools import wraps | |||
| from flask import request | |||
| from flask import request, current_app | |||
| from flask_login import user_logged_in | |||
| from flask_restful import Resource | |||
| from werkzeug.exceptions import NotFound, Unauthorized | |||
| from core.login.login import _get_user | |||
| from extensions.ext_database import db | |||
| from models.account import Tenant, TenantAccountJoin, Account | |||
| from models.dataset import Dataset | |||
| from models.model import ApiToken, App | |||
| @@ -43,12 +46,24 @@ def validate_dataset_token(view=None): | |||
| @wraps(view) | |||
| def decorated(*args, **kwargs): | |||
| api_token = validate_and_get_api_token('dataset') | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == api_token.dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound() | |||
| return view(dataset, *args, **kwargs) | |||
| tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ | |||
| .filter(Tenant.id == api_token.tenant_id) \ | |||
| .filter(TenantAccountJoin.tenant_id == Tenant.id) \ | |||
| .filter(TenantAccountJoin.role == 'owner') \ | |||
| .one_or_none() | |||
| if tenant_account_join: | |||
| tenant, ta = tenant_account_join | |||
| account = Account.query.filter_by(id=ta.account_id).first() | |||
| # Login admin | |||
| if account: | |||
| account.current_tenant = tenant | |||
| current_app.login_manager._update_request_context_with_user(account) | |||
| user_logged_in.send(current_app._get_current_object(), user=_get_user()) | |||
| else: | |||
| raise Unauthorized("Tenant owner account is not exist.") | |||
| else: | |||
| raise Unauthorized("Tenant is not exist.") | |||
| return view(api_token.tenant_id, *args, **kwargs) | |||
| return decorated | |||
| if view: | |||
| @@ -6,26 +6,12 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.web import api | |||
| from controllers.web.error import NotChatAppError | |||
| from controllers.web.wraps import WebApiResource | |||
| from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields | |||
| from libs.helper import TimestampField, uuid_value | |||
| from services.conversation_service import ConversationService | |||
| from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError | |||
| from services.web_conversation_service import WebConversationService | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'status': fields.String, | |||
| 'introduction': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| conversation_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(conversation_fields)) | |||
| } | |||
| class ConversationListApi(WebApiResource): | |||
| @@ -73,7 +59,7 @@ class ConversationApi(WebApiResource): | |||
| class ConversationRenameApi(WebApiResource): | |||
| @marshal_with(conversation_fields) | |||
| @marshal_with(simple_conversation_fields) | |||
| def post(self, app_model, end_user, c_id): | |||
| if app_model.mode != 'chat': | |||
| raise NotChatAppError() | |||
| @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) | |||
| BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children" | |||
| DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query" | |||
| SEARCH_URL = "https://api.notion.com/v1/search" | |||
| RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}" | |||
| RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}" | |||
| HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3'] | |||
| @@ -246,11 +246,28 @@ class KeywordTableIndex(BaseIndex): | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def multi_create_segment_keywords(self, pre_segment_data_list: list): | |||
| keyword_table_handler = JiebaKeywordTableHandler() | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| for pre_segment_data in pre_segment_data_list: | |||
| segment = pre_segment_data['segment'] | |||
| if pre_segment_data['keywords']: | |||
| segment.keywords = pre_segment_data['keywords'] | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, | |||
| pre_segment_data['keywords']) | |||
| else: | |||
| keywords = keyword_table_handler.extract_keywords(segment.content, | |||
| self._config.max_keywords_per_chunk) | |||
| segment.keywords = list(keywords) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def update_segment_keywords_index(self, node_id: str, keywords: List[str]): | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| class KeywordTableRetriever(BaseRetriever, BaseModel): | |||
| index: KeywordTableIndex | |||
| search_kwargs: dict = Field(default_factory=dict) | |||
| @@ -0,0 +1,138 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| app_detail_kernel_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| } | |||
| related_app_list = { | |||
| 'data': fields.List(fields.Nested(app_detail_kernel_fields)), | |||
| 'total': fields.Integer, | |||
| } | |||
| model_config_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), | |||
| 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), | |||
| 'speech_to_text': fields.Raw(attribute='speech_to_text_dict'), | |||
| 'retriever_resource': fields.Raw(attribute='retriever_resource_dict'), | |||
| 'more_like_this': fields.Raw(attribute='more_like_this_dict'), | |||
| 'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'), | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'user_input_form': fields.Raw(attribute='user_input_form_list'), | |||
| 'dataset_query_variable': fields.String, | |||
| 'pre_prompt': fields.String, | |||
| 'agent_mode': fields.Raw(attribute='agent_mode_dict'), | |||
| } | |||
| app_detail_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'enable_site': fields.Boolean, | |||
| 'enable_api': fields.Boolean, | |||
| 'api_rpm': fields.Integer, | |||
| 'api_rph': fields.Integer, | |||
| 'is_demo': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), | |||
| 'created_at': TimestampField | |||
| } | |||
| prompt_config_fields = { | |||
| 'prompt_template': fields.String, | |||
| } | |||
| model_config_partial_fields = { | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'pre_prompt': fields.String, | |||
| } | |||
| app_partial_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'enable_site': fields.Boolean, | |||
| 'enable_api': fields.Boolean, | |||
| 'is_demo': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), | |||
| 'created_at': TimestampField | |||
| } | |||
| app_pagination_fields = { | |||
| 'page': fields.Integer, | |||
| 'limit': fields.Integer(attribute='per_page'), | |||
| 'total': fields.Integer, | |||
| 'has_more': fields.Boolean(attribute='has_next'), | |||
| 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') | |||
| } | |||
| template_fields = { | |||
| 'name': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'description': fields.String, | |||
| 'mode': fields.String, | |||
| 'model_config': fields.Nested(model_config_fields), | |||
| } | |||
| template_list_fields = { | |||
| 'data': fields.List(fields.Nested(template_fields)), | |||
| } | |||
| site_fields = { | |||
| 'access_token': fields.String(attribute='code'), | |||
| 'code': fields.String, | |||
| 'title': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'description': fields.String, | |||
| 'default_language': fields.String, | |||
| 'customize_domain': fields.String, | |||
| 'copyright': fields.String, | |||
| 'privacy_policy': fields.String, | |||
| 'customize_token_strategy': fields.String, | |||
| 'prompt_public': fields.Boolean, | |||
| 'app_base_url': fields.String, | |||
| } | |||
| app_detail_fields_with_site = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'enable_site': fields.Boolean, | |||
| 'enable_api': fields.Boolean, | |||
| 'api_rpm': fields.Integer, | |||
| 'api_rph': fields.Integer, | |||
| 'is_demo': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), | |||
| 'site': fields.Nested(site_fields), | |||
| 'api_base_url': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| app_site_fields = { | |||
| 'app_id': fields.String, | |||
| 'access_token': fields.String(attribute='code'), | |||
| 'code': fields.String, | |||
| 'title': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String, | |||
| 'description': fields.String, | |||
| 'default_language': fields.String, | |||
| 'customize_domain': fields.String, | |||
| 'copyright': fields.String, | |||
| 'privacy_policy': fields.String, | |||
| 'customize_token_strategy': fields.String, | |||
| 'prompt_public': fields.Boolean | |||
| } | |||
| @@ -0,0 +1,182 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| class MessageTextField(fields.Raw): | |||
| def format(self, value): | |||
| return value[0]['text'] if value else '' | |||
| account_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'email': fields.String | |||
| } | |||
| feedback_fields = { | |||
| 'rating': fields.String, | |||
| 'content': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account': fields.Nested(account_fields, allow_null=True), | |||
| } | |||
| annotation_fields = { | |||
| 'content': fields.String, | |||
| 'account': fields.Nested(account_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| message_detail_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'message': fields.Raw, | |||
| 'message_tokens': fields.Integer, | |||
| 'answer': fields.String, | |||
| 'answer_tokens': fields.Integer, | |||
| 'provider_response_latency': fields.Float, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'feedbacks': fields.List(fields.Nested(feedback_fields)), | |||
| 'annotation': fields.Nested(annotation_fields, allow_null=True), | |||
| 'created_at': TimestampField | |||
| } | |||
| feedback_stat_fields = { | |||
| 'like': fields.Integer, | |||
| 'dislike': fields.Integer | |||
| } | |||
| model_config_fields = { | |||
| 'opening_statement': fields.String, | |||
| 'suggested_questions': fields.Raw, | |||
| 'model': fields.Raw, | |||
| 'user_input_form': fields.Raw, | |||
| 'pre_prompt': fields.String, | |||
| 'agent_mode': fields.Raw, | |||
| } | |||
| simple_configs_fields = { | |||
| 'prompt_template': fields.String, | |||
| } | |||
| simple_model_config_fields = { | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'pre_prompt': fields.String, | |||
| } | |||
| simple_message_detail_fields = { | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'message': MessageTextField, | |||
| 'answer': fields.String, | |||
| } | |||
| conversation_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_end_user_session_id': fields.String(), | |||
| 'from_account_id': fields.String, | |||
| 'read_at': TimestampField, | |||
| 'created_at': TimestampField, | |||
| 'annotation': fields.Nested(annotation_fields, allow_null=True), | |||
| 'model_config': fields.Nested(simple_model_config_fields), | |||
| 'user_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'admin_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'message': fields.Nested(simple_message_detail_fields, attribute='first_message') | |||
| } | |||
| conversation_pagination_fields = { | |||
| 'page': fields.Integer, | |||
| 'limit': fields.Integer(attribute='per_page'), | |||
| 'total': fields.Integer, | |||
| 'has_more': fields.Boolean(attribute='has_next'), | |||
| 'data': fields.List(fields.Nested(conversation_fields), attribute='items') | |||
| } | |||
| conversation_message_detail_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'model_config': fields.Nested(model_config_fields), | |||
| 'message': fields.Nested(message_detail_fields, attribute='first_message'), | |||
| } | |||
| simple_model_config_fields = { | |||
| 'model': fields.Raw(attribute='model_dict'), | |||
| 'pre_prompt': fields.String, | |||
| } | |||
| conversation_with_summary_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_end_user_session_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'summary': fields.String(attribute='summary_or_query'), | |||
| 'read_at': TimestampField, | |||
| 'created_at': TimestampField, | |||
| 'annotated': fields.Boolean, | |||
| 'model_config': fields.Nested(simple_model_config_fields), | |||
| 'message_count': fields.Integer, | |||
| 'user_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'admin_feedback_stats': fields.Nested(feedback_stat_fields) | |||
| } | |||
| conversation_with_summary_pagination_fields = { | |||
| 'page': fields.Integer, | |||
| 'limit': fields.Integer(attribute='per_page'), | |||
| 'total': fields.Integer, | |||
| 'has_more': fields.Boolean(attribute='has_next'), | |||
| 'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items') | |||
| } | |||
| conversation_detail_fields = { | |||
| 'id': fields.String, | |||
| 'status': fields.String, | |||
| 'from_source': fields.String, | |||
| 'from_end_user_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'annotated': fields.Boolean, | |||
| 'model_config': fields.Nested(model_config_fields), | |||
| 'message_count': fields.Integer, | |||
| 'user_feedback_stats': fields.Nested(feedback_stat_fields), | |||
| 'admin_feedback_stats': fields.Nested(feedback_stat_fields) | |||
| } | |||
| simple_conversation_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'status': fields.String, | |||
| 'introduction': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| conversation_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(simple_conversation_fields)) | |||
| } | |||
| conversation_with_model_config_fields = { | |||
| **simple_conversation_fields, | |||
| 'model_config': fields.Raw, | |||
| } | |||
| conversation_with_model_config_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(conversation_with_model_config_fields)) | |||
| } | |||
| @@ -0,0 +1,65 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| integrate_icon_fields = { | |||
| 'type': fields.String, | |||
| 'url': fields.String, | |||
| 'emoji': fields.String | |||
| } | |||
| integrate_page_fields = { | |||
| 'page_name': fields.String, | |||
| 'page_id': fields.String, | |||
| 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), | |||
| 'is_bound': fields.Boolean, | |||
| 'parent_id': fields.String, | |||
| 'type': fields.String | |||
| } | |||
| integrate_workspace_fields = { | |||
| 'workspace_name': fields.String, | |||
| 'workspace_id': fields.String, | |||
| 'workspace_icon': fields.String, | |||
| 'pages': fields.List(fields.Nested(integrate_page_fields)) | |||
| } | |||
| integrate_notion_info_list_fields = { | |||
| 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), | |||
| } | |||
| integrate_icon_fields = { | |||
| 'type': fields.String, | |||
| 'url': fields.String, | |||
| 'emoji': fields.String | |||
| } | |||
| integrate_page_fields = { | |||
| 'page_name': fields.String, | |||
| 'page_id': fields.String, | |||
| 'page_icon': fields.Nested(integrate_icon_fields, allow_null=True), | |||
| 'parent_id': fields.String, | |||
| 'type': fields.String | |||
| } | |||
| integrate_workspace_fields = { | |||
| 'workspace_name': fields.String, | |||
| 'workspace_id': fields.String, | |||
| 'workspace_icon': fields.String, | |||
| 'pages': fields.List(fields.Nested(integrate_page_fields)), | |||
| 'total': fields.Integer | |||
| } | |||
| integrate_fields = { | |||
| 'id': fields.String, | |||
| 'provider': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'is_bound': fields.Boolean, | |||
| 'disabled': fields.Boolean, | |||
| 'link': fields.String, | |||
| 'source_info': fields.Nested(integrate_workspace_fields) | |||
| } | |||
| integrate_list_fields = { | |||
| 'data': fields.List(fields.Nested(integrate_fields)), | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| dataset_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'permission': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'indexing_technique': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| } | |||
| dataset_detail_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'description': fields.String, | |||
| 'provider': fields.String, | |||
| 'permission': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'indexing_technique': fields.String, | |||
| 'app_count': fields.Integer, | |||
| 'document_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'updated_by': fields.String, | |||
| 'updated_at': TimestampField, | |||
| 'embedding_model': fields.String, | |||
| 'embedding_model_provider': fields.String, | |||
| 'embedding_available': fields.Boolean | |||
| } | |||
| dataset_query_detail_fields = { | |||
| "id": fields.String, | |||
| "content": fields.String, | |||
| "source": fields.String, | |||
| "source_app_id": fields.String, | |||
| "created_by_role": fields.String, | |||
| "created_by": fields.String, | |||
| "created_at": TimestampField | |||
| } | |||
| @@ -0,0 +1,76 @@ | |||
| from flask_restful import fields | |||
| from fields.dataset_fields import dataset_fields | |||
| from libs.helper import TimestampField | |||
| document_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'data_source_type': fields.String, | |||
| 'data_source_info': fields.Raw(attribute='data_source_info_dict'), | |||
| 'dataset_process_rule_id': fields.String, | |||
| 'name': fields.String, | |||
| 'created_from': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'tokens': fields.Integer, | |||
| 'indexing_status': fields.String, | |||
| 'error': fields.String, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'archived': fields.Boolean, | |||
| 'display_status': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'hit_count': fields.Integer, | |||
| 'doc_form': fields.String, | |||
| } | |||
| document_with_segments_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'data_source_type': fields.String, | |||
| 'data_source_info': fields.Raw(attribute='data_source_info_dict'), | |||
| 'dataset_process_rule_id': fields.String, | |||
| 'name': fields.String, | |||
| 'created_from': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'tokens': fields.Integer, | |||
| 'indexing_status': fields.String, | |||
| 'error': fields.String, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'archived': fields.Boolean, | |||
| 'display_status': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'hit_count': fields.Integer, | |||
| 'completed_segments': fields.Integer, | |||
| 'total_segments': fields.Integer | |||
| } | |||
| dataset_and_document_fields = { | |||
| 'dataset': fields.Nested(dataset_fields), | |||
| 'documents': fields.List(fields.Nested(document_fields)), | |||
| 'batch': fields.String | |||
| } | |||
| document_status_fields = { | |||
| 'id': fields.String, | |||
| 'indexing_status': fields.String, | |||
| 'processing_started_at': TimestampField, | |||
| 'parsing_completed_at': TimestampField, | |||
| 'cleaning_completed_at': TimestampField, | |||
| 'splitting_completed_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'paused_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField, | |||
| 'completed_segments': fields.Integer, | |||
| 'total_segments': fields.Integer, | |||
| } | |||
| document_status_fields_list = { | |||
| 'data': fields.List(fields.Nested(document_status_fields)) | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| upload_config_fields = { | |||
| 'file_size_limit': fields.Integer, | |||
| 'batch_count_limit': fields.Integer | |||
| } | |||
| file_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'size': fields.Integer, | |||
| 'extension': fields.String, | |||
| 'mime_type': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| } | |||
| @@ -0,0 +1,41 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| document_fields = { | |||
| 'id': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'name': fields.String, | |||
| 'doc_type': fields.String, | |||
| } | |||
| segment_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'document_id': fields.String, | |||
| 'content': fields.String, | |||
| 'answer': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'tokens': fields.Integer, | |||
| 'keywords': fields.List(fields.String), | |||
| 'index_node_id': fields.String, | |||
| 'index_node_hash': fields.String, | |||
| 'hit_count': fields.Integer, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'status': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'indexing_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField, | |||
| 'document': fields.Nested(document_fields), | |||
| } | |||
| hit_testing_record_fields = { | |||
| 'segment': fields.Nested(segment_fields), | |||
| 'score': fields.Float, | |||
| 'tsne_position': fields.Raw | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| app_fields = { | |||
| 'id': fields.String, | |||
| 'name': fields.String, | |||
| 'mode': fields.String, | |||
| 'icon': fields.String, | |||
| 'icon_background': fields.String | |||
| } | |||
| installed_app_fields = { | |||
| 'id': fields.String, | |||
| 'app': fields.Nested(app_fields), | |||
| 'app_owner_tenant_id': fields.String, | |||
| 'is_pinned': fields.Boolean, | |||
| 'last_used_at': TimestampField, | |||
| 'editable': fields.Boolean, | |||
| 'uninstallable': fields.Boolean, | |||
| } | |||
| installed_app_list_fields = { | |||
| 'installed_apps': fields.List(fields.Nested(installed_app_fields)) | |||
| } | |||
| @@ -0,0 +1,43 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| feedback_fields = { | |||
| 'rating': fields.String | |||
| } | |||
| retriever_resource_fields = { | |||
| 'id': fields.String, | |||
| 'message_id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'dataset_id': fields.String, | |||
| 'dataset_name': fields.String, | |||
| 'document_id': fields.String, | |||
| 'document_name': fields.String, | |||
| 'data_source_type': fields.String, | |||
| 'segment_id': fields.String, | |||
| 'score': fields.Float, | |||
| 'hit_count': fields.Integer, | |||
| 'word_count': fields.Integer, | |||
| 'segment_position': fields.Integer, | |||
| 'index_node_hash': fields.String, | |||
| 'content': fields.String, | |||
| 'created_at': TimestampField | |||
| } | |||
| message_fields = { | |||
| 'id': fields.String, | |||
| 'conversation_id': fields.String, | |||
| 'inputs': fields.Raw, | |||
| 'query': fields.String, | |||
| 'answer': fields.String, | |||
| 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), | |||
| 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), | |||
| 'created_at': TimestampField | |||
| } | |||
| message_infinite_scroll_pagination_fields = { | |||
| 'limit': fields.Integer, | |||
| 'has_more': fields.Boolean, | |||
| 'data': fields.List(fields.Nested(message_fields)) | |||
| } | |||
| @@ -0,0 +1,32 @@ | |||
| from flask_restful import fields | |||
| from libs.helper import TimestampField | |||
| segment_fields = { | |||
| 'id': fields.String, | |||
| 'position': fields.Integer, | |||
| 'document_id': fields.String, | |||
| 'content': fields.String, | |||
| 'answer': fields.String, | |||
| 'word_count': fields.Integer, | |||
| 'tokens': fields.Integer, | |||
| 'keywords': fields.List(fields.String), | |||
| 'index_node_id': fields.String, | |||
| 'index_node_hash': fields.String, | |||
| 'hit_count': fields.Integer, | |||
| 'enabled': fields.Boolean, | |||
| 'disabled_at': TimestampField, | |||
| 'disabled_by': fields.String, | |||
| 'status': fields.String, | |||
| 'created_by': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'indexing_at': TimestampField, | |||
| 'completed_at': TimestampField, | |||
| 'error': fields.String, | |||
| 'stopped_at': TimestampField | |||
| } | |||
| segment_list_response = { | |||
| 'data': fields.List(fields.Nested(segment_fields)), | |||
| 'has_more': fields.Boolean, | |||
| 'limit': fields.Integer | |||
| } | |||
| @@ -0,0 +1,36 @@ | |||
| """add_tenant_id_in_api_token | |||
| Revision ID: 2e9819ca5b28 | |||
| Revises: 6e2cfb077b04 | |||
| Create Date: 2023-09-22 15:41:01.243183 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '2e9819ca5b28' | |||
| down_revision = 'ab23c11305d4' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('api_tokens', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True)) | |||
| batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False) | |||
| batch_op.drop_column('dataset_id') | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('api_tokens', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True)) | |||
| batch_op.drop_index('api_token_tenant_idx') | |||
| batch_op.drop_column('tenant_id') | |||
| # ### end Alembic commands ### | |||
| @@ -629,12 +629,13 @@ class ApiToken(db.Model): | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='api_token_pkey'), | |||
| db.Index('api_token_app_id_type_idx', 'app_id', 'type'), | |||
| db.Index('api_token_token_idx', 'token', 'type') | |||
| db.Index('api_token_token_idx', 'token', 'type'), | |||
| db.Index('api_token_tenant_idx', 'tenant_id', 'type') | |||
| ) | |||
| id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) | |||
| app_id = db.Column(UUID, nullable=True) | |||
| dataset_id = db.Column(UUID, nullable=True) | |||
| tenant_id = db.Column(UUID, nullable=True) | |||
| type = db.Column(db.String(16), nullable=False) | |||
| token = db.Column(db.String(255), nullable=False) | |||
| last_used_at = db.Column(db.DateTime, nullable=True) | |||
| @@ -96,7 +96,7 @@ class DatasetService: | |||
| embedding_model = None | |||
| if indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=current_user.current_tenant_id | |||
| tenant_id=tenant_id | |||
| ) | |||
| dataset = Dataset(name=name, indexing_technique=indexing_technique) | |||
| # dataset = Dataset(name=name, provider=provider, config=config) | |||
| @@ -477,6 +477,7 @@ class DocumentService: | |||
| ) | |||
| dataset.collection_binding_id = dataset_collection_binding.id | |||
| documents = [] | |||
| batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) | |||
| if 'original_document_id' in document_data and document_data["original_document_id"]: | |||
| @@ -626,6 +627,9 @@ class DocumentService: | |||
| document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) | |||
| if document.display_status != 'available': | |||
| raise ValueError("Document is not available") | |||
| # update document name | |||
| if 'name' in document_data and document_data['name']: | |||
| document.name = document_data['name'] | |||
| # save process rule | |||
| if 'process_rule' in document_data and document_data['process_rule']: | |||
| process_rule = document_data["process_rule"] | |||
| @@ -767,7 +771,7 @@ class DocumentService: | |||
| return dataset, documents, batch | |||
| @classmethod | |||
| def document_create_args_validate(cls, args: dict): | |||
| def document_create_args_validate(cls, args: dict): | |||
| if 'original_document_id' not in args or not args['original_document_id']: | |||
| DocumentService.data_source_args_validate(args) | |||
| DocumentService.process_rule_args_validate(args) | |||
| @@ -1014,6 +1018,66 @@ class SegmentService: | |||
| segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() | |||
| return segment | |||
| @classmethod | |||
| def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): | |||
| embedding_model = None | |||
| if dataset.indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| ) | |||
| max_position = db.session.query(func.max(DocumentSegment.position)).filter( | |||
| DocumentSegment.document_id == document.id | |||
| ).scalar() | |||
| pre_segment_data_list = [] | |||
| segment_data_list = [] | |||
| for segment_item in segments: | |||
| content = segment_item['content'] | |||
| doc_id = str(uuid.uuid4()) | |||
| segment_hash = helper.generate_text_hash(content) | |||
| tokens = 0 | |||
| if dataset.indexing_technique == 'high_quality' and embedding_model: | |||
| # calc embedding use tokens | |||
| tokens = embedding_model.get_num_tokens(content) | |||
| segment_document = DocumentSegment( | |||
| tenant_id=current_user.current_tenant_id, | |||
| dataset_id=document.dataset_id, | |||
| document_id=document.id, | |||
| index_node_id=doc_id, | |||
| index_node_hash=segment_hash, | |||
| position=max_position + 1 if max_position else 1, | |||
| content=content, | |||
| word_count=len(content), | |||
| tokens=tokens, | |||
| status='completed', | |||
| indexing_at=datetime.datetime.utcnow(), | |||
| completed_at=datetime.datetime.utcnow(), | |||
| created_by=current_user.id | |||
| ) | |||
| if document.doc_form == 'qa_model': | |||
| segment_document.answer = segment_item['answer'] | |||
| db.session.add(segment_document) | |||
| segment_data_list.append(segment_document) | |||
| pre_segment_data = { | |||
| 'segment': segment_document, | |||
| 'keywords': segment_item['keywords'] | |||
| } | |||
| pre_segment_data_list.append(pre_segment_data) | |||
| try: | |||
| # save vector index | |||
| VectorService.multi_create_segment_vector(pre_segment_data_list, dataset) | |||
| except Exception as e: | |||
| logging.exception("create segment index failed") | |||
| for segment_document in segment_data_list: | |||
| segment_document.enabled = False | |||
| segment_document.disabled_at = datetime.datetime.utcnow() | |||
| segment_document.status = 'error' | |||
| segment_document.error = str(e) | |||
| db.session.commit() | |||
| return segment_data_list | |||
| @classmethod | |||
| def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): | |||
| indexing_cache_key = 'segment_{}_indexing'.format(segment.id) | |||
| @@ -1,7 +1,7 @@ | |||
| # -*- coding:utf-8 -*- | |||
| __all__ = [ | |||
| 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', | |||
| 'app', 'completion', 'audio' | |||
| 'app', 'completion', 'audio', 'file' | |||
| ] | |||
| from . import * | |||
| @@ -3,3 +3,11 @@ from services.errors.base import BaseServiceError | |||
| class FileNotExistsError(BaseServiceError): | |||
| pass | |||
| class FileTooLargeError(BaseServiceError): | |||
| description = "{message}" | |||
| class UnsupportedFileTypeError(BaseServiceError): | |||
| pass | |||
| @@ -0,0 +1,123 @@ | |||
| import datetime | |||
| import hashlib | |||
| import time | |||
| import uuid | |||
| from cachetools import TTLCache | |||
| from flask import request, current_app | |||
| from flask_login import current_user | |||
| from werkzeug.datastructures import FileStorage | |||
| from werkzeug.exceptions import NotFound | |||
| from core.data_loader.file_extractor import FileExtractor | |||
| from extensions.ext_storage import storage | |||
| from extensions.ext_database import db | |||
| from models.model import UploadFile | |||
| from services.errors.file import FileTooLargeError, UnsupportedFileTypeError | |||
| ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv'] | |||
| PREVIEW_WORDS_LIMIT = 3000 | |||
| cache = TTLCache(maxsize=None, ttl=30) | |||
| class FileService: | |||
| @staticmethod | |||
| def upload_file(file: FileStorage) -> UploadFile: | |||
| # read file content | |||
| file_content = file.read() | |||
| # get file size | |||
| file_size = len(file_content) | |||
| file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024 | |||
| if file_size > file_size_limit: | |||
| message = f'File size exceeded. {file_size} > {file_size_limit}' | |||
| raise FileTooLargeError(message) | |||
| extension = file.filename.split('.')[-1] | |||
| if extension.lower() not in ALLOWED_EXTENSIONS: | |||
| raise UnsupportedFileTypeError() | |||
| # user uuid as file name | |||
| file_uuid = str(uuid.uuid4()) | |||
| file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension | |||
| # save file to storage | |||
| storage.save(file_key, file_content) | |||
| # save file to db | |||
| config = current_app.config | |||
| upload_file = UploadFile( | |||
| tenant_id=current_user.current_tenant_id, | |||
| storage_type=config['STORAGE_TYPE'], | |||
| key=file_key, | |||
| name=file.filename, | |||
| size=file_size, | |||
| extension=extension, | |||
| mime_type=file.mimetype, | |||
| created_by=current_user.id, | |||
| created_at=datetime.datetime.utcnow(), | |||
| used=False, | |||
| hash=hashlib.sha3_256(file_content).hexdigest() | |||
| ) | |||
| db.session.add(upload_file) | |||
| db.session.commit() | |||
| return upload_file | |||
| @staticmethod | |||
| def upload_text(text: str, text_name: str) -> UploadFile: | |||
| # user uuid as file name | |||
| file_uuid = str(uuid.uuid4()) | |||
| file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' | |||
| # save file to storage | |||
| storage.save(file_key, text.encode('utf-8')) | |||
| # save file to db | |||
| config = current_app.config | |||
| upload_file = UploadFile( | |||
| tenant_id=current_user.current_tenant_id, | |||
| storage_type=config['STORAGE_TYPE'], | |||
| key=file_key, | |||
| name=text_name + '.txt', | |||
| size=len(text), | |||
| extension='txt', | |||
| mime_type='text/plain', | |||
| created_by=current_user.id, | |||
| created_at=datetime.datetime.utcnow(), | |||
| used=True, | |||
| used_by=current_user.id, | |||
| used_at=datetime.datetime.utcnow() | |||
| ) | |||
| db.session.add(upload_file) | |||
| db.session.commit() | |||
| return upload_file | |||
| @staticmethod | |||
| def get_file_preview(file_id: str) -> str: | |||
| # get file storage key | |||
| key = file_id + request.path | |||
| cached_response = cache.get(key) | |||
| if cached_response and time.time() - cached_response['timestamp'] < cache.ttl: | |||
| return cached_response['response'] | |||
| upload_file = db.session.query(UploadFile) \ | |||
| .filter(UploadFile.id == file_id) \ | |||
| .first() | |||
| if not upload_file: | |||
| raise NotFound("File not found") | |||
| # extract text from file | |||
| extension = upload_file.extension | |||
| if extension.lower() not in ALLOWED_EXTENSIONS: | |||
| raise UnsupportedFileTypeError() | |||
| text = FileExtractor.load(upload_file, return_text=True) | |||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | |||
| return text | |||
| @@ -35,6 +35,32 @@ class VectorService: | |||
| else: | |||
| index.add_texts([document]) | |||
| @classmethod | |||
| def multi_create_segment_vector(cls, pre_segment_data_list: list, dataset: Dataset): | |||
| documents = [] | |||
| for pre_segment_data in pre_segment_data_list: | |||
| segment = pre_segment_data['segment'] | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| # save vector index | |||
| index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| if index: | |||
| index.add_texts(documents, duplicate_check=True) | |||
| # save keyword index | |||
| keyword_index = IndexBuilder.get_index(dataset, 'economy') | |||
| if keyword_index: | |||
| keyword_index.multi_create_segment_keywords(pre_segment_data_list) | |||
| @classmethod | |||
| def update_segment_vector(cls, keywords: Optional[List[str]], segment: DocumentSegment, dataset: Dataset): | |||
| # update segment index task | |||