Signed-off-by: -LAN- <laipz8200@outlook.com>
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from libs.helper import AppIconUrlField | |||
| @@ -10,6 +10,12 @@ parameters__system_parameters = { | |||
| "workflow_file_upload_limit": fields.Integer, | |||
| } | |||
| def build_system_parameters_model(api_or_ns: Api | Namespace): | |||
| """Build the system parameters model for the API or Namespace.""" | |||
| return api_or_ns.model("SystemParameters", parameters__system_parameters) | |||
| parameters_fields = { | |||
| "opening_statement": fields.String, | |||
| "suggested_questions": fields.Raw, | |||
| @@ -25,6 +31,14 @@ parameters_fields = { | |||
| "system_parameters": fields.Nested(parameters__system_parameters), | |||
| } | |||
| def build_parameters_model(api_or_ns: Api | Namespace): | |||
| """Build the parameters model for the API or Namespace.""" | |||
| copied_fields = parameters_fields.copy() | |||
| copied_fields["system_parameters"] = fields.Nested(build_system_parameters_model(api_or_ns)) | |||
| return api_or_ns.model("Parameters", copied_fields) | |||
| site_fields = { | |||
| "title": fields.String, | |||
| "chat_color_theme": fields.String, | |||
| @@ -41,3 +55,8 @@ site_fields = { | |||
| "show_workflow_steps": fields.Boolean, | |||
| "use_icon_as_answer_icon": fields.Boolean, | |||
| } | |||
| def build_site_model(api_or_ns: Api | Namespace): | |||
| """Build the site model for the API or Namespace.""" | |||
| return api_or_ns.model("Site", site_fields) | |||
| @@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.wraps import account_initialization_required, setup_required | |||
| from fields.tag_fields import tag_fields | |||
| from fields.tag_fields import dataset_tag_fields | |||
| from libs.login import login_required | |||
| from models.model import Tag | |||
| from services.tag_service import TagService | |||
| @@ -21,7 +21,7 @@ class TagListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @marshal_with(tag_fields) | |||
| @marshal_with(dataset_tag_fields) | |||
| def get(self): | |||
| tag_type = request.args.get("type", type=str, default="") | |||
| keyword = request.args.get("keyword", default=None, type=str) | |||
| @@ -1,11 +1,23 @@ | |||
| from flask import Blueprint | |||
| from flask_restx import Namespace | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint("service_api", __name__, url_prefix="/v1") | |||
| api = ExternalApi(bp) | |||
| api = ExternalApi( | |||
| bp, | |||
| version="1.0", | |||
| title="Service API", | |||
| description="API for application services", | |||
| doc="/docs", # Enable Swagger UI at /v1/docs | |||
| ) | |||
| service_api_ns = Namespace("service_api", description="Service operations") | |||
| from . import index | |||
| from .app import annotation, app, audio, completion, conversation, file, file_preview, message, site, workflow | |||
| from .dataset import dataset, document, hit_testing, metadata, segment, upload_file | |||
| from .workspace import models | |||
| api.add_namespace(service_api_ns) | |||
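In the package __init__, the ExternalApi instance (the project's Api subclass from libs.external_api) now receives version, title, description, and doc metadata so a Swagger UI is served at /v1/docs, and a single shared Namespace is created; the controller modules are imported so their @service_api_ns.route declarations register on it, and the namespace is attached last. A rough sketch of the equivalent wiring with a stock flask_restx.Api (the /ping resource is purely illustrative):

```python
from flask import Blueprint, Flask
from flask_restx import Api, Namespace, Resource

bp = Blueprint("service_api_demo", __name__, url_prefix="/v1")
api = Api(
    bp,
    version="1.0",
    title="Service API",
    description="API for application services",
    doc="/docs",  # Swagger UI served at /v1/docs
)
ns = Namespace("service_api", description="Service operations")


@ns.route("/ping")
class PingApi(Resource):
    def get(self):
        return {"result": "pong"}


# Attach the namespace once the resources are declared, then register the
# blueprint on the Flask application as usual.
api.add_namespace(ns)

app = Flask(__name__)
app.register_blueprint(bp)
```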
| @@ -1,28 +1,51 @@ | |||
| from typing import Literal | |||
| from flask import request | |||
| from flask_restx import Resource, marshal, marshal_with, reqparse | |||
| from flask_restx import Api, Namespace, Resource, fields, reqparse | |||
| from flask_restx.api import HTTPStatus | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import validate_app_token | |||
| from extensions.ext_redis import redis_client | |||
| from fields.annotation_fields import ( | |||
| annotation_fields, | |||
| ) | |||
| from fields.annotation_fields import annotation_fields, build_annotation_model | |||
| from libs.login import current_user | |||
| from models.model import App | |||
| from services.annotation_service import AppAnnotationService | |||
| # Define parsers for annotation API | |||
| annotation_create_parser = reqparse.RequestParser() | |||
| annotation_create_parser.add_argument("question", required=True, type=str, location="json", help="Annotation question") | |||
| annotation_create_parser.add_argument("answer", required=True, type=str, location="json", help="Annotation answer") | |||
| annotation_reply_action_parser = reqparse.RequestParser() | |||
| annotation_reply_action_parser.add_argument( | |||
| "score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching" | |||
| ) | |||
| annotation_reply_action_parser.add_argument( | |||
| "embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name" | |||
| ) | |||
| annotation_reply_action_parser.add_argument( | |||
| "embedding_model_name", required=True, type=str, location="json", help="Embedding model name" | |||
| ) | |||
| @service_api_ns.route("/apps/annotation-reply/<string:action>") | |||
| class AnnotationReplyActionApi(Resource): | |||
| @service_api_ns.expect(annotation_reply_action_parser) | |||
| @service_api_ns.doc("annotation_reply_action") | |||
| @service_api_ns.doc(description="Enable or disable annotation reply feature") | |||
| @service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Action completed successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| def post(self, app_model: App, action: Literal["enable", "disable"]): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("score_threshold", required=True, type=float, location="json") | |||
| parser.add_argument("embedding_provider_name", required=True, type=str, location="json") | |||
| parser.add_argument("embedding_model_name", required=True, type=str, location="json") | |||
| args = parser.parse_args() | |||
| """Enable or disable annotation reply feature.""" | |||
| args = annotation_reply_action_parser.parse_args() | |||
| if action == "enable": | |||
| result = AppAnnotationService.enable_app_annotation(args, app_model.id) | |||
| elif action == "disable": | |||
| @@ -30,9 +53,21 @@ class AnnotationReplyActionApi(Resource): | |||
| return result, 200 | |||
| @service_api_ns.route("/apps/annotation-reply/<string:action>/status/<uuid:job_id>") | |||
| class AnnotationReplyActionStatusApi(Resource): | |||
| @service_api_ns.doc("get_annotation_reply_action_status") | |||
| @service_api_ns.doc(description="Get the status of an annotation reply action job") | |||
| @service_api_ns.doc(params={"action": "Action type", "job_id": "Job ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Job status retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Job not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| def get(self, app_model: App, job_id, action): | |||
| """Get the status of an annotation reply action job.""" | |||
| job_id = str(job_id) | |||
| app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}" | |||
| cache_result = redis_client.get(app_annotation_job_key) | |||
| @@ -48,60 +83,111 @@ class AnnotationReplyActionStatusApi(Resource): | |||
| return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 | |||
| # Define annotation list response model | |||
| annotation_list_fields = { | |||
| "data": fields.List(fields.Nested(annotation_fields)), | |||
| "has_more": fields.Boolean, | |||
| "limit": fields.Integer, | |||
| "total": fields.Integer, | |||
| "page": fields.Integer, | |||
| } | |||
| def build_annotation_list_model(api_or_ns: Api | Namespace): | |||
| """Build the annotation list model for the API or Namespace.""" | |||
| copied_annotation_list_fields = annotation_list_fields.copy() | |||
| copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns))) | |||
| return api_or_ns.model("AnnotationList", copied_annotation_list_fields) | |||
| @service_api_ns.route("/apps/annotations") | |||
| class AnnotationListApi(Resource): | |||
| @service_api_ns.doc("list_annotations") | |||
| @service_api_ns.doc(description="List annotations for the application") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Annotations retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @service_api_ns.marshal_with(build_annotation_list_model(service_api_ns)) | |||
| def get(self, app_model: App): | |||
| """List annotations for the application.""" | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| keyword = request.args.get("keyword", default="", type=str) | |||
| annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword) | |||
| response = { | |||
| "data": marshal(annotation_list, annotation_fields), | |||
| return { | |||
| "data": annotation_list, | |||
| "has_more": len(annotation_list) == limit, | |||
| "limit": limit, | |||
| "total": total, | |||
| "page": page, | |||
| } | |||
| return response, 200 | |||
| @service_api_ns.expect(annotation_create_parser) | |||
| @service_api_ns.doc("create_annotation") | |||
| @service_api_ns.doc(description="Create a new annotation") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 201: "Annotation created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @marshal_with(annotation_fields) | |||
| @service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED) | |||
| def post(self, app_model: App): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("question", required=True, type=str, location="json") | |||
| parser.add_argument("answer", required=True, type=str, location="json") | |||
| args = parser.parse_args() | |||
| """Create a new annotation.""" | |||
| args = annotation_create_parser.parse_args() | |||
| annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id) | |||
| return annotation | |||
| return annotation, 201 | |||
| @service_api_ns.route("/apps/annotations/<uuid:annotation_id>") | |||
| class AnnotationUpdateDeleteApi(Resource): | |||
| @service_api_ns.expect(annotation_create_parser) | |||
| @service_api_ns.doc("update_annotation") | |||
| @service_api_ns.doc(description="Update an existing annotation") | |||
| @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Annotation updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| 404: "Annotation not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @marshal_with(annotation_fields) | |||
| @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) | |||
| def put(self, app_model: App, annotation_id): | |||
| """Update an existing annotation.""" | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| annotation_id = str(annotation_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("question", required=True, type=str, location="json") | |||
| parser.add_argument("answer", required=True, type=str, location="json") | |||
| args = parser.parse_args() | |||
| args = annotation_create_parser.parse_args() | |||
| annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) | |||
| return annotation | |||
| @service_api_ns.doc("delete_annotation") | |||
| @service_api_ns.doc(description="Delete an annotation") | |||
| @service_api_ns.doc(params={"annotation_id": "Annotation ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Annotation deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| 404: "Annotation not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| def delete(self, app_model: App, annotation_id): | |||
| """Delete an annotation.""" | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| annotation_id = str(annotation_id) | |||
| AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) | |||
| return {"result": "success"}, 204 | |||
| api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>") | |||
| api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>") | |||
| api.add_resource(AnnotationListApi, "/apps/annotations") | |||
| api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>") | |||
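With routing handled by the @service_api_ns.route decorators, the trailing api.add_resource calls for the annotation endpoints are dropped. From a client's perspective the paths are unchanged; a hedged usage sketch follows, where the base URL and the Bearer token header are assumptions about a typical deployment rather than something this diff defines:

```python
import requests

BASE_URL = "http://localhost:5001/v1"          # assumed local deployment
HEADERS = {"Authorization": "Bearer app-xxx"}  # assumed app API token header

# Create an annotation (POST /apps/annotations now returns 201 Created).
created = requests.post(
    f"{BASE_URL}/apps/annotations",
    headers=HEADERS,
    json={"question": "What is the refund policy?", "answer": "Refunds within 30 days."},
)
print(created.status_code, created.json())

# List annotations (GET /apps/annotations) with pagination parameters.
listed = requests.get(
    f"{BASE_URL}/apps/annotations",
    headers=HEADERS,
    params={"page": 1, "limit": 20},
)
print(listed.json()["total"], "annotations")
```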
| @@ -1,7 +1,7 @@ | |||
| from flask_restx import Resource, marshal_with | |||
| from flask_restx import Resource | |||
| from controllers.common import fields | |||
| from controllers.service_api import api | |||
| from controllers.common.fields import build_parameters_model | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import AppUnavailableError | |||
| from controllers.service_api.wraps import validate_app_token | |||
| from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict | |||
| @@ -9,13 +9,26 @@ from models.model import App, AppMode | |||
| from services.app_service import AppService | |||
| @service_api_ns.route("/parameters") | |||
| class AppParameterApi(Resource): | |||
| """Resource for app variables.""" | |||
| @service_api_ns.doc("get_app_parameters") | |||
| @service_api_ns.doc(description="Retrieve application input parameters and configuration") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Parameters retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Application not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @marshal_with(fields.parameters_fields) | |||
| @service_api_ns.marshal_with(build_parameters_model(service_api_ns)) | |||
| def get(self, app_model: App): | |||
| """Retrieve app parameters.""" | |||
| """Retrieve app parameters. | |||
| Returns the input form parameters and configuration for the application. | |||
| """ | |||
| if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: | |||
| workflow = app_model.workflow | |||
| if workflow is None: | |||
| @@ -35,17 +48,43 @@ class AppParameterApi(Resource): | |||
| return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form) | |||
| @service_api_ns.route("/meta") | |||
| class AppMetaApi(Resource): | |||
| @service_api_ns.doc("get_app_meta") | |||
| @service_api_ns.doc(description="Get application metadata") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Metadata retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Application not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| def get(self, app_model: App): | |||
| """Get app meta""" | |||
| """Get app metadata. | |||
| Returns metadata about the application including configuration and settings. | |||
| """ | |||
| return AppService().get_app_meta(app_model) | |||
| @service_api_ns.route("/info") | |||
| class AppInfoApi(Resource): | |||
| @service_api_ns.doc("get_app_info") | |||
| @service_api_ns.doc(description="Get basic application information") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Application info retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Application not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| def get(self, app_model: App): | |||
| """Get app information""" | |||
| """Get app information. | |||
| Returns basic information about the application including name, description, tags, and mode. | |||
| """ | |||
| tags = [tag.name for tag in app_model.tags] | |||
| return { | |||
| "name": app_model.name, | |||
| @@ -54,8 +93,3 @@ class AppInfoApi(Resource): | |||
| "mode": app_model.mode, | |||
| "author_name": app_model.author_name, | |||
| } | |||
| api.add_resource(AppParameterApi, "/parameters") | |||
| api.add_resource(AppMetaApi, "/meta") | |||
| api.add_resource(AppInfoApi, "/info") | |||
| @@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse | |||
| from werkzeug.exceptions import InternalServerError | |||
| import services | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import ( | |||
| AppUnavailableError, | |||
| AudioTooLargeError, | |||
| @@ -30,9 +30,26 @@ from services.errors.audio import ( | |||
| ) | |||
| @service_api_ns.route("/audio-to-text") | |||
| class AudioApi(Resource): | |||
| @service_api_ns.doc("audio_to_text") | |||
| @service_api_ns.doc(description="Convert audio to text using speech-to-text") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Audio successfully transcribed", | |||
| 400: "Bad request - no audio or invalid audio", | |||
| 401: "Unauthorized - invalid API token", | |||
| 413: "Audio file too large", | |||
| 415: "Unsupported audio type", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| """Convert audio to text using speech-to-text. | |||
| Accepts an audio file upload and returns the transcribed text. | |||
| """ | |||
| file = request.files["file"] | |||
| try: | |||
| @@ -65,16 +82,35 @@ class AudioApi(Resource): | |||
| raise InternalServerError() | |||
| # Define parser for text-to-audio API | |||
| text_to_audio_parser = reqparse.RequestParser() | |||
| text_to_audio_parser.add_argument("message_id", type=str, required=False, location="json", help="Message ID") | |||
| text_to_audio_parser.add_argument("voice", type=str, location="json", help="Voice to use for TTS") | |||
| text_to_audio_parser.add_argument("text", type=str, location="json", help="Text to convert to audio") | |||
| text_to_audio_parser.add_argument("streaming", type=bool, location="json", help="Enable streaming response") | |||
| @service_api_ns.route("/text-to-audio") | |||
| class TextApi(Resource): | |||
| @service_api_ns.expect(text_to_audio_parser) | |||
| @service_api_ns.doc("text_to_audio") | |||
| @service_api_ns.doc(description="Convert text to audio using text-to-speech") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Text successfully converted to audio", | |||
| 400: "Bad request - invalid parameters", | |||
| 401: "Unauthorized - invalid API token", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| """Convert text to audio using text-to-speech. | |||
| Converts the provided text to audio using the specified voice. | |||
| """ | |||
| try: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("message_id", type=str, required=False, location="json") | |||
| parser.add_argument("voice", type=str, location="json") | |||
| parser.add_argument("text", type=str, location="json") | |||
| parser.add_argument("streaming", type=bool, location="json") | |||
| args = parser.parse_args() | |||
| args = text_to_audio_parser.parse_args() | |||
| message_id = args.get("message_id", None) | |||
| text = args.get("text", None) | |||
| @@ -108,7 +144,3 @@ class TextApi(Resource): | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| raise InternalServerError() | |||
| api.add_resource(AudioApi, "/audio-to-text") | |||
| api.add_resource(TextApi, "/text-to-audio") | |||
| @@ -5,7 +5,7 @@ from flask_restx import Resource, reqparse | |||
| from werkzeug.exceptions import BadRequest, InternalServerError, NotFound | |||
| import services | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import ( | |||
| AppUnavailableError, | |||
| CompletionRequestError, | |||
| @@ -33,21 +33,68 @@ from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError | |||
| from services.errors.llm import InvokeRateLimitError | |||
| # Define parser for completion API | |||
| completion_parser = reqparse.RequestParser() | |||
| completion_parser.add_argument( | |||
| "inputs", type=dict, required=True, location="json", help="Input parameters for completion" | |||
| ) | |||
| completion_parser.add_argument("query", type=str, location="json", default="", help="The query string") | |||
| completion_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") | |||
| completion_parser.add_argument( | |||
| "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" | |||
| ) | |||
| completion_parser.add_argument( | |||
| "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" | |||
| ) | |||
| # Define parser for chat API | |||
| chat_parser = reqparse.RequestParser() | |||
| chat_parser.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat") | |||
| chat_parser.add_argument("query", type=str, required=True, location="json", help="The chat query") | |||
| chat_parser.add_argument("files", type=list, required=False, location="json", help="List of file attachments") | |||
| chat_parser.add_argument( | |||
| "response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode" | |||
| ) | |||
| chat_parser.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID") | |||
| chat_parser.add_argument( | |||
| "retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source" | |||
| ) | |||
| chat_parser.add_argument( | |||
| "auto_generate_name", | |||
| type=bool, | |||
| required=False, | |||
| default=True, | |||
| location="json", | |||
| help="Auto generate conversation name", | |||
| ) | |||
| chat_parser.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat") | |||
| @service_api_ns.route("/completion-messages") | |||
| class CompletionApi(Resource): | |||
| @service_api_ns.expect(completion_parser) | |||
| @service_api_ns.doc("create_completion") | |||
| @service_api_ns.doc(description="Create a completion for the given prompt") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Completion created successfully", | |||
| 400: "Bad request - invalid parameters", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation not found", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| """Create a completion for the given prompt. | |||
| This endpoint generates a completion based on the provided inputs and query. | |||
| Supports both blocking and streaming response modes. | |||
| """ | |||
| if app_model.mode != "completion": | |||
| raise AppUnavailableError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, location="json", default="") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") | |||
| args = parser.parse_args() | |||
| args = completion_parser.parse_args() | |||
| external_trace_id = get_external_trace_id(request) | |||
| if external_trace_id: | |||
| args["external_trace_id"] = external_trace_id | |||
| @@ -88,9 +135,21 @@ class CompletionApi(Resource): | |||
| raise InternalServerError() | |||
| @service_api_ns.route("/completion-messages/<string:task_id>/stop") | |||
| class CompletionStopApi(Resource): | |||
| @service_api_ns.doc("stop_completion") | |||
| @service_api_ns.doc(description="Stop a running completion task") | |||
| @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Task stopped successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Task not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, task_id): | |||
| def post(self, app_model: App, end_user: EndUser, task_id: str): | |||
| """Stop a running completion task.""" | |||
| if app_model.mode != "completion": | |||
| raise AppUnavailableError() | |||
| @@ -99,23 +158,33 @@ class CompletionStopApi(Resource): | |||
| return {"result": "success"}, 200 | |||
| @service_api_ns.route("/chat-messages") | |||
| class ChatApi(Resource): | |||
| @service_api_ns.expect(chat_parser) | |||
| @service_api_ns.doc("create_chat_message") | |||
| @service_api_ns.doc(description="Send a message in a chat conversation") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Message sent successfully", | |||
| 400: "Bad request - invalid parameters or workflow issues", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation or workflow not found", | |||
| 429: "Rate limit exceeded", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| """Send a message in a chat conversation. | |||
| This endpoint handles chat messages for chat, agent chat, and advanced chat applications. | |||
| Supports conversation management and both blocking and streaming response modes. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, location="json") | |||
| parser.add_argument("query", type=str, required=True, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| parser.add_argument("conversation_id", type=uuid_value, location="json") | |||
| parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") | |||
| parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") | |||
| parser.add_argument("workflow_id", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| args = chat_parser.parse_args() | |||
| external_trace_id = get_external_trace_id(request) | |||
| if external_trace_id: | |||
| @@ -159,9 +228,21 @@ class ChatApi(Resource): | |||
| raise InternalServerError() | |||
| @service_api_ns.route("/chat-messages/<string:task_id>/stop") | |||
| class ChatStopApi(Resource): | |||
| @service_api_ns.doc("stop_chat_message") | |||
| @service_api_ns.doc(description="Stop a running chat message generation") | |||
| @service_api_ns.doc(params={"task_id": "The ID of the task to stop"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Task stopped successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Task not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, task_id): | |||
| def post(self, app_model: App, end_user: EndUser, task_id: str): | |||
| """Stop a running chat message generation.""" | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| @@ -169,9 +250,3 @@ class ChatStopApi(Resource): | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| return {"result": "success"}, 200 | |||
| api.add_resource(CompletionApi, "/completion-messages") | |||
| api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop") | |||
| api.add_resource(ChatApi, "/chat-messages") | |||
| api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop") | |||
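The completion and chat endpoints likewise lose their api.add_resource registrations. The request payload accepted by /chat-messages is unchanged and mirrors chat_parser above; the sketch below is a hedged client-side example where the base URL, token header, and the user field (resolved from the JSON body by validate_app_token) are assumptions about a typical deployment:

```python
import requests

BASE_URL = "http://localhost:5001/v1"          # assumed deployment URL
HEADERS = {"Authorization": "Bearer app-xxx"}  # assumed app API token header

# Send a chat message in blocking mode; field names mirror chat_parser above.
resp = requests.post(
    f"{BASE_URL}/chat-messages",
    headers=HEADERS,
    json={
        "inputs": {},
        "query": "Hello, what can you do?",
        "response_mode": "blocking",
        "auto_generate_name": True,
        "user": "end-user-123",  # assumed end-user identifier field
    },
)
print(resp.status_code, resp.json().get("answer"))
```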
| @@ -1,48 +1,97 @@ | |||
| from flask_restx import Resource, marshal_with, reqparse | |||
| from flask_restx import Resource, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from sqlalchemy.orm import Session | |||
| from werkzeug.exceptions import BadRequest, NotFound | |||
| import services | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import NotChatAppError | |||
| from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from fields.conversation_fields import ( | |||
| conversation_delete_fields, | |||
| conversation_infinite_scroll_pagination_fields, | |||
| simple_conversation_fields, | |||
| build_conversation_delete_model, | |||
| build_conversation_infinite_scroll_pagination_model, | |||
| build_simple_conversation_model, | |||
| ) | |||
| from fields.conversation_variable_fields import ( | |||
| conversation_variable_fields, | |||
| conversation_variable_infinite_scroll_pagination_fields, | |||
| build_conversation_variable_infinite_scroll_pagination_model, | |||
| build_conversation_variable_model, | |||
| ) | |||
| from libs.helper import uuid_value | |||
| from models.model import App, AppMode, EndUser | |||
| from services.conversation_service import ConversationService | |||
| # Define parsers for conversation APIs | |||
| conversation_list_parser = reqparse.RequestParser() | |||
| conversation_list_parser.add_argument( | |||
| "last_id", type=uuid_value, location="args", help="Last conversation ID for pagination" | |||
| ) | |||
| conversation_list_parser.add_argument( | |||
| "limit", | |||
| type=int_range(1, 100), | |||
| required=False, | |||
| default=20, | |||
| location="args", | |||
| help="Number of conversations to return", | |||
| ) | |||
| conversation_list_parser.add_argument( | |||
| "sort_by", | |||
| type=str, | |||
| choices=["created_at", "-created_at", "updated_at", "-updated_at"], | |||
| required=False, | |||
| default="-updated_at", | |||
| location="args", | |||
| help="Sort order for conversations", | |||
| ) | |||
| conversation_rename_parser = reqparse.RequestParser() | |||
| conversation_rename_parser.add_argument("name", type=str, required=False, location="json", help="New conversation name") | |||
| conversation_rename_parser.add_argument( | |||
| "auto_generate", type=bool, required=False, default=False, location="json", help="Auto-generate conversation name" | |||
| ) | |||
| conversation_variables_parser = reqparse.RequestParser() | |||
| conversation_variables_parser.add_argument( | |||
| "last_id", type=uuid_value, location="args", help="Last variable ID for pagination" | |||
| ) | |||
| conversation_variables_parser.add_argument( | |||
| "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of variables to return" | |||
| ) | |||
| conversation_variable_update_parser = reqparse.RequestParser() | |||
| # using lambda is for passing the already-typed value without modification | |||
| # if no lambda, it will be converted to string | |||
| # the string cannot be converted using json.loads | |||
| conversation_variable_update_parser.add_argument( | |||
| "value", required=True, location="json", type=lambda x: x, help="New value for the conversation variable" | |||
| ) | |||
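The identity lambda on the value argument is worth a note: flask-restx's RequestParser coerces arguments through str by default, so an already-decoded JSON value (a list, dict, or number) would arrive as its string repr and could no longer be parsed downstream; passing the value through unchanged preserves its type. A small self-contained check of that behavior, assuming nothing beyond Flask's test request context:

```python
from flask import Flask
from flask_restx import reqparse

app = Flask(__name__)

parser = reqparse.RequestParser()
# type=lambda x: x keeps the already-decoded JSON value untouched;
# the default coercion would turn [1, 2] into the string "[1, 2]".
parser.add_argument("value", required=True, location="json", type=lambda x: x)

with app.test_request_context("/vars", method="PUT", json={"value": [1, 2]}):
    args = parser.parse_args()
    assert args["value"] == [1, 2]  # still a Python list
```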
| @service_api_ns.route("/conversations") | |||
| class ConversationApi(Resource): | |||
| @service_api_ns.expect(conversation_list_parser) | |||
| @service_api_ns.doc("list_conversations") | |||
| @service_api_ns.doc(description="List all conversations for the current user") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Conversations retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Last conversation not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @marshal_with(conversation_infinite_scroll_pagination_fields) | |||
| @service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns)) | |||
| def get(self, app_model: App, end_user: EndUser): | |||
| """List all conversations for the current user. | |||
| Supports pagination using last_id and limit parameters. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| parser.add_argument( | |||
| "sort_by", | |||
| type=str, | |||
| choices=["created_at", "-created_at", "updated_at", "-updated_at"], | |||
| required=False, | |||
| default="-updated_at", | |||
| location="args", | |||
| ) | |||
| args = parser.parse_args() | |||
| args = conversation_list_parser.parse_args() | |||
| try: | |||
| with Session(db.engine) as session: | |||
| @@ -59,10 +108,22 @@ class ConversationApi(Resource): | |||
| raise NotFound("Last Conversation Not Exists.") | |||
| @service_api_ns.route("/conversations/<uuid:c_id>") | |||
| class ConversationDetailApi(Resource): | |||
| @service_api_ns.doc("delete_conversation") | |||
| @service_api_ns.doc(description="Delete a specific conversation") | |||
| @service_api_ns.doc(params={"c_id": "Conversation ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Conversation deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @marshal_with(conversation_delete_fields) | |||
| @service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=204) | |||
| def delete(self, app_model: App, end_user: EndUser, c_id): | |||
| """Delete a specific conversation.""" | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| @@ -76,20 +137,30 @@ class ConversationDetailApi(Resource): | |||
| return {"result": "success"}, 204 | |||
| @service_api_ns.route("/conversations/<uuid:c_id>/name") | |||
| class ConversationRenameApi(Resource): | |||
| @service_api_ns.expect(conversation_rename_parser) | |||
| @service_api_ns.doc("rename_conversation") | |||
| @service_api_ns.doc(description="Rename a conversation or auto-generate a name") | |||
| @service_api_ns.doc(params={"c_id": "Conversation ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Conversation renamed successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @marshal_with(simple_conversation_fields) | |||
| @service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns)) | |||
| def post(self, app_model: App, end_user: EndUser, c_id): | |||
| """Rename a conversation or auto-generate a name.""" | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| conversation_id = str(c_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=False, location="json") | |||
| parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") | |||
| args = parser.parse_args() | |||
| args = conversation_rename_parser.parse_args() | |||
| try: | |||
| return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) | |||
| @@ -97,10 +168,26 @@ class ConversationRenameApi(Resource): | |||
| raise NotFound("Conversation Not Exists.") | |||
| @service_api_ns.route("/conversations/<uuid:c_id>/variables") | |||
| class ConversationVariablesApi(Resource): | |||
| @service_api_ns.expect(conversation_variables_parser) | |||
| @service_api_ns.doc("list_conversation_variables") | |||
| @service_api_ns.doc(description="List all variables for a conversation") | |||
| @service_api_ns.doc(params={"c_id": "Conversation ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Variables retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @marshal_with(conversation_variable_infinite_scroll_pagination_fields) | |||
| @service_api_ns.marshal_with(build_conversation_variable_infinite_scroll_pagination_model(service_api_ns)) | |||
| def get(self, app_model: App, end_user: EndUser, c_id): | |||
| """List all variables for a conversation. | |||
| Conversational variables are only available for chat applications. | |||
| """ | |||
| # conversational variable only for chat app | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| @@ -108,10 +195,7 @@ class ConversationVariablesApi(Resource): | |||
| conversation_id = str(c_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("last_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| args = conversation_variables_parser.parse_args() | |||
| try: | |||
| return ConversationService.get_conversational_variable( | |||
| @@ -121,11 +205,28 @@ class ConversationVariablesApi(Resource): | |||
| raise NotFound("Conversation Not Exists.") | |||
| @service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>") | |||
| class ConversationVariableDetailApi(Resource): | |||
| @service_api_ns.expect(conversation_variable_update_parser) | |||
| @service_api_ns.doc("update_conversation_variable") | |||
| @service_api_ns.doc(description="Update a conversation variable's value") | |||
| @service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Variable updated successfully", | |||
| 400: "Bad request - type mismatch", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation or variable not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) | |||
| @marshal_with(conversation_variable_fields) | |||
| @service_api_ns.marshal_with(build_conversation_variable_model(service_api_ns)) | |||
| def put(self, app_model: App, end_user: EndUser, c_id, variable_id): | |||
| """Update a conversation variable's value""" | |||
| """Update a conversation variable's value. | |||
| Allows updating the value of a specific conversation variable. | |||
| The value must match the variable's expected type. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| @@ -133,12 +234,7 @@ class ConversationVariableDetailApi(Resource): | |||
| conversation_id = str(c_id) | |||
| variable_id = str(variable_id) | |||
| parser = reqparse.RequestParser() | |||
| # using lambda is for passing the already-typed value without modification | |||
| # if no lambda, it will be converted to string | |||
| # the string cannot be converted using json.loads | |||
| parser.add_argument("value", required=True, location="json", type=lambda x: x) | |||
| args = parser.parse_args() | |||
| args = conversation_variable_update_parser.parse_args() | |||
| try: | |||
| return ConversationService.update_conversation_variable( | |||
| @@ -150,15 +246,3 @@ class ConversationVariableDetailApi(Resource): | |||
| raise NotFound("Conversation Variable Not Exists.") | |||
| except services.errors.conversation.ConversationVariableTypeMismatchError as e: | |||
| raise BadRequest(str(e)) | |||
| api.add_resource(ConversationRenameApi, "/conversations/<uuid:c_id>/name", endpoint="conversation_name") | |||
| api.add_resource(ConversationApi, "/conversations") | |||
| api.add_resource(ConversationDetailApi, "/conversations/<uuid:c_id>", endpoint="conversation_detail") | |||
| api.add_resource(ConversationVariablesApi, "/conversations/<uuid:c_id>/variables", endpoint="conversation_variables") | |||
| api.add_resource( | |||
| ConversationVariableDetailApi, | |||
| "/conversations/<uuid:c_id>/variables/<uuid:variable_id>", | |||
| endpoint="conversation_variable_detail", | |||
| methods=["PUT"], | |||
| ) | |||
| @@ -1,5 +1,6 @@ | |||
| from flask import request | |||
| from flask_restx import Resource, marshal_with | |||
| from flask_restx import Resource | |||
| from flask_restx.api import HTTPStatus | |||
| import services | |||
| from controllers.common.errors import ( | |||
| @@ -9,17 +10,33 @@ from controllers.common.errors import ( | |||
| TooManyFilesError, | |||
| UnsupportedFileTypeError, | |||
| ) | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token | |||
| from fields.file_fields import file_fields | |||
| from fields.file_fields import build_file_model | |||
| from models.model import App, EndUser | |||
| from services.file_service import FileService | |||
| @service_api_ns.route("/files/upload") | |||
| class FileApi(Resource): | |||
| @service_api_ns.doc("upload_file") | |||
| @service_api_ns.doc(description="Upload a file for use in conversations") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 201: "File uploaded successfully", | |||
| 400: "Bad request - no file or invalid file", | |||
| 401: "Unauthorized - invalid API token", | |||
| 413: "File too large", | |||
| 415: "Unsupported file type", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) | |||
| @marshal_with(file_fields) | |||
| @service_api_ns.marshal_with(build_file_model(service_api_ns), code=HTTPStatus.CREATED) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| """Upload a file for use in conversations. | |||
| Accepts a single file upload via multipart/form-data. | |||
| """ | |||
| # check file | |||
| if "file" not in request.files: | |||
| raise NoFileUploadedError() | |||
| @@ -47,6 +64,3 @@ class FileApi(Resource): | |||
| raise UnsupportedFileTypeError() | |||
| return upload_file, 201 | |||
| api.add_resource(FileApi, "/files/upload") | |||
| @@ -4,7 +4,7 @@ from urllib.parse import quote | |||
| from flask import Response | |||
| from flask_restx import Resource, reqparse | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import ( | |||
| FileAccessDeniedError, | |||
| FileNotFoundError, | |||
| @@ -17,6 +17,14 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile | |||
| logger = logging.getLogger(__name__) | |||
| # Define parser for file preview API | |||
| file_preview_parser = reqparse.RequestParser() | |||
| file_preview_parser.add_argument( | |||
| "as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment" | |||
| ) | |||
| @service_api_ns.route("/files/<uuid:file_id>/preview") | |||
| class FilePreviewApi(Resource): | |||
| """ | |||
| Service API File Preview endpoint | |||
| @@ -25,33 +33,30 @@ class FilePreviewApi(Resource): | |||
| Files can only be accessed if they belong to messages within the requesting app's context. | |||
| """ | |||
| @service_api_ns.expect(file_preview_parser) | |||
| @service_api_ns.doc("preview_file") | |||
| @service_api_ns.doc(description="Preview or download a file uploaded via Service API") | |||
| @service_api_ns.doc(params={"file_id": "UUID of the file to preview"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "File retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - file access denied", | |||
| 404: "File not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| def get(self, app_model: App, end_user: EndUser, file_id: str): | |||
| """ | |||
| Preview/Download a file that was uploaded via Service API | |||
| Args: | |||
| app_model: The authenticated app model | |||
| end_user: The authenticated end user (optional) | |||
| file_id: UUID of the file to preview | |||
| Query Parameters: | |||
| user: Optional user identifier | |||
| as_attachment: Boolean, whether to download as attachment (default: false) | |||
| Preview/Download a file that was uploaded via Service API. | |||
| Returns: | |||
| Stream response with file content | |||
| Raises: | |||
| FileNotFoundError: File does not exist | |||
| FileAccessDeniedError: File access denied (not owned by app) | |||
| Provides secure file preview/download functionality. | |||
| Files can only be accessed if they belong to messages within the requesting app's context. | |||
| """ | |||
| file_id = str(file_id) | |||
| # Parse query parameters | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") | |||
| args = parser.parse_args() | |||
| args = file_preview_parser.parse_args() | |||
| # Validate file ownership and get file objects | |||
| message_file, upload_file = self._validate_file_ownership(file_id, app_model.id) | |||
| @@ -180,7 +185,3 @@ class FilePreviewApi(Resource): | |||
| response.headers["Cache-Control"] = "public, max-age=3600" # Cache for 1 hour | |||
| return response | |||
| # Register the API endpoint | |||
| api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview") | |||
| @@ -1,17 +1,17 @@ | |||
| import json | |||
| import logging | |||
| from flask_restx import Resource, fields, marshal_with, reqparse | |||
| from flask_restx import Api, Namespace, Resource, fields, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from werkzeug.exceptions import BadRequest, InternalServerError, NotFound | |||
| import services | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import NotChatAppError | |||
| from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token | |||
| from core.app.entities.app_invoke_entities import InvokeFrom | |||
| from fields.conversation_fields import message_file_fields | |||
| from fields.message_fields import agent_thought_fields, feedback_fields | |||
| from fields.conversation_fields import build_message_file_model | |||
| from fields.message_fields import build_agent_thought_model, build_feedback_model | |||
| from fields.raws import FilesContainedField | |||
| from libs.helper import TimestampField, uuid_value | |||
| from models.model import App, AppMode, EndUser | |||
| @@ -22,8 +22,37 @@ from services.errors.message import ( | |||
| ) | |||
| from services.message_service import MessageService | |||
| # Define parsers for message APIs | |||
| message_list_parser = reqparse.RequestParser() | |||
| message_list_parser.add_argument( | |||
| "conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID" | |||
| ) | |||
| message_list_parser.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination") | |||
| message_list_parser.add_argument( | |||
| "limit", type=int_range(1, 100), required=False, default=20, location="args", help="Number of messages to return" | |||
| ) | |||
| class MessageListApi(Resource): | |||
| message_feedback_parser = reqparse.RequestParser() | |||
| message_feedback_parser.add_argument( | |||
| "rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating" | |||
| ) | |||
| message_feedback_parser.add_argument("content", type=str, location="json", help="Feedback content") | |||
| feedback_list_parser = reqparse.RequestParser() | |||
| feedback_list_parser.add_argument("page", type=int, default=1, location="args", help="Page number") | |||
| feedback_list_parser.add_argument( | |||
| "limit", type=int_range(1, 101), required=False, default=20, location="args", help="Number of feedbacks per page" | |||
| ) | |||
| def build_message_model(api_or_ns: Api | Namespace): | |||
| """Build the message model for the API or Namespace.""" | |||
| # First build the nested models | |||
| feedback_model = build_feedback_model(api_or_ns) | |||
| agent_thought_model = build_agent_thought_model(api_or_ns) | |||
| message_file_model = build_message_file_model(api_or_ns) | |||
| # Then build the message fields with nested models | |||
| message_fields = { | |||
| "id": fields.String, | |||
| "conversation_id": fields.String, | |||
| @@ -31,37 +60,58 @@ class MessageListApi(Resource): | |||
| "inputs": FilesContainedField, | |||
| "query": fields.String, | |||
| "answer": fields.String(attribute="re_sign_file_url_answer"), | |||
| "message_files": fields.List(fields.Nested(message_file_fields)), | |||
| "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), | |||
| "message_files": fields.List(fields.Nested(message_file_model)), | |||
| "feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True), | |||
| "retriever_resources": fields.Raw( | |||
| attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", []) | |||
| if obj.message_metadata | |||
| else [] | |||
| ), | |||
| "created_at": TimestampField, | |||
| "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), | |||
| "agent_thoughts": fields.List(fields.Nested(agent_thought_model)), | |||
| "status": fields.String, | |||
| "error": fields.String, | |||
| } | |||
| return api_or_ns.model("Message", message_fields) | |||
| def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): | |||
| """Build the message infinite scroll pagination model for the API or Namespace.""" | |||
| # Build the nested message model first | |||
| message_model = build_message_model(api_or_ns) | |||
| message_infinite_scroll_pagination_fields = { | |||
| "limit": fields.Integer, | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(message_fields)), | |||
| "data": fields.List(fields.Nested(message_model)), | |||
| } | |||
| return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields) | |||
| @service_api_ns.route("/messages") | |||
| class MessageListApi(Resource): | |||
| @service_api_ns.expect(message_list_parser) | |||
| @service_api_ns.doc("list_messages") | |||
| @service_api_ns.doc(description="List messages in a conversation") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Messages retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Conversation or first message not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| @service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns)) | |||
| def get(self, app_model: App, end_user: EndUser): | |||
| """List messages in a conversation. | |||
| Retrieves messages with pagination support using first_id. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| raise NotChatAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") | |||
| parser.add_argument("first_id", type=uuid_value, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| args = message_list_parser.parse_args() | |||
| try: | |||
| return MessageService.pagination_by_first_id( | |||
| @@ -73,15 +123,28 @@ class MessageListApi(Resource): | |||
| raise NotFound("First Message Not Exists.") | |||
| @service_api_ns.route("/messages/<uuid:message_id>/feedbacks") | |||
| class MessageFeedbackApi(Resource): | |||
| @service_api_ns.expect(message_feedback_parser) | |||
| @service_api_ns.doc("create_message_feedback") | |||
| @service_api_ns.doc(description="Submit feedback for a message") | |||
| @service_api_ns.doc(params={"message_id": "Message ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Feedback submitted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Message not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, message_id): | |||
| """Submit feedback for a message. | |||
| Allows users to rate messages as like/dislike and provide optional feedback content. | |||
| """ | |||
| message_id = str(message_id) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") | |||
| parser.add_argument("content", type=str, location="json") | |||
| args = parser.parse_args() | |||
| args = message_feedback_parser.parse_args() | |||
| try: | |||
| MessageService.create_feedback( | |||
| @@ -97,21 +160,48 @@ class MessageFeedbackApi(Resource): | |||
| return {"result": "success"} | |||
| @service_api_ns.route("/app/feedbacks") | |||
| class AppGetFeedbacksApi(Resource): | |||
| @service_api_ns.expect(feedback_list_parser) | |||
| @service_api_ns.doc("get_app_feedbacks") | |||
| @service_api_ns.doc(description="Get all feedbacks for the application") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Feedbacks retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| def get(self, app_model: App): | |||
| """Get All Feedbacks of an app""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("page", type=int, default=1, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args") | |||
| args = parser.parse_args() | |||
| """Get all feedbacks for the application. | |||
| Returns paginated list of all feedback submitted for messages in this app. | |||
| """ | |||
| args = feedback_list_parser.parse_args() | |||
| feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) | |||
| return {"data": feedbacks} | |||
| @service_api_ns.route("/messages/<uuid:message_id>/suggested") | |||
| class MessageSuggestedApi(Resource): | |||
| @service_api_ns.doc("get_suggested_questions") | |||
| @service_api_ns.doc(description="Get suggested follow-up questions for a message") | |||
| @service_api_ns.doc(params={"message_id": "Message ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Suggested questions retrieved successfully", | |||
| 400: "Suggested questions feature is disabled", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Message not found", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) | |||
| def get(self, app_model: App, end_user: EndUser, message_id): | |||
| """Get suggested follow-up questions for a message. | |||
| Returns AI-generated follow-up questions based on the message content. | |||
| """ | |||
| message_id = str(message_id) | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: | |||
| @@ -130,9 +220,3 @@ class MessageSuggestedApi(Resource): | |||
| raise InternalServerError() | |||
| return {"result": "success", "data": questions} | |||
| api.add_resource(MessageListApi, "/messages") | |||
| api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks") | |||
| api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested") | |||
| api.add_resource(AppGetFeedbacksApi, "/app/feedbacks") | |||
| @@ -1,30 +1,41 @@ | |||
| from flask_restx import Resource, marshal_with | |||
| from flask_restx import Resource | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.common import fields | |||
| from controllers.service_api import api | |||
| from controllers.common.fields import build_site_model | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import validate_app_token | |||
| from extensions.ext_database import db | |||
| from models.account import TenantStatus | |||
| from models.model import App, Site | |||
| @service_api_ns.route("/site") | |||
| class AppSiteApi(Resource): | |||
| """Resource for app sites.""" | |||
| @service_api_ns.doc("get_app_site") | |||
| @service_api_ns.doc(description="Get application site configuration") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Site configuration retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - site not found or tenant archived", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @marshal_with(fields.site_fields) | |||
| @service_api_ns.marshal_with(build_site_model(service_api_ns)) | |||
| def get(self, app_model: App): | |||
| """Retrieve app site info.""" | |||
| """Retrieve app site info. | |||
| Returns the site configuration for the application including theme, icons, and text. | |||
| """ | |||
| site = db.session.query(Site).where(Site.app_id == app_model.id).first() | |||
| if not site: | |||
| raise Forbidden() | |||
| assert app_model.tenant | |||
| if app_model.tenant.status == TenantStatus.ARCHIVE: | |||
| raise Forbidden() | |||
| return site | |||
| api.add_resource(AppSiteApi, "/site") | |||
| @@ -2,12 +2,12 @@ import logging | |||
| from dateutil.parser import isoparse | |||
| from flask import request | |||
| from flask_restx import Resource, fields, marshal_with, reqparse | |||
| from flask_restx import Api, Namespace, Resource, fields, reqparse | |||
| from flask_restx.inputs import int_range | |||
| from sqlalchemy.orm import Session, sessionmaker | |||
| from werkzeug.exceptions import BadRequest, InternalServerError, NotFound | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import ( | |||
| CompletionRequestError, | |||
| NotWorkflowAppError, | |||
| @@ -28,7 +28,7 @@ from core.helper.trace_id_helper import get_external_trace_id | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.workflow.entities.workflow_execution import WorkflowExecutionStatus | |||
| from extensions.ext_database import db | |||
| from fields.workflow_app_log_fields import workflow_app_log_pagination_fields | |||
| from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model | |||
| from libs import helper | |||
| from libs.helper import TimestampField | |||
| from models.model import App, AppMode, EndUser | |||
| @@ -40,6 +40,34 @@ from services.workflow_app_service import WorkflowAppService | |||
| logger = logging.getLogger(__name__) | |||
| # Define parsers for workflow APIs | |||
| workflow_run_parser = reqparse.RequestParser() | |||
| workflow_run_parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| workflow_run_parser.add_argument("files", type=list, required=False, location="json") | |||
| workflow_run_parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| workflow_log_parser = reqparse.RequestParser() | |||
| workflow_log_parser.add_argument("keyword", type=str, location="args") | |||
| workflow_log_parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") | |||
| workflow_log_parser.add_argument("created_at__before", type=str, location="args") | |||
| workflow_log_parser.add_argument("created_at__after", type=str, location="args") | |||
| workflow_log_parser.add_argument( | |||
| "created_by_end_user_session_id", | |||
| type=str, | |||
| location="args", | |||
| required=False, | |||
| default=None, | |||
| ) | |||
| workflow_log_parser.add_argument( | |||
| "created_by_account", | |||
| type=str, | |||
| location="args", | |||
| required=False, | |||
| default=None, | |||
| ) | |||
| workflow_log_parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") | |||
| workflow_log_parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| workflow_run_fields = { | |||
| "id": fields.String, | |||
| "workflow_id": fields.String, | |||
| @@ -55,12 +83,29 @@ workflow_run_fields = { | |||
| } | |||
| def build_workflow_run_model(api_or_ns: Api | Namespace): | |||
| """Build the workflow run model for the API or Namespace.""" | |||
| return api_or_ns.model("WorkflowRun", workflow_run_fields) | |||
| @service_api_ns.route("/workflows/run/<string:workflow_run_id>") | |||
| class WorkflowRunDetailApi(Resource): | |||
| @service_api_ns.doc("get_workflow_run_detail") | |||
| @service_api_ns.doc(description="Get workflow run details") | |||
| @service_api_ns.doc(params={"workflow_run_id": "Workflow run ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Workflow run details retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Workflow run not found", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @marshal_with(workflow_run_fields) | |||
| @service_api_ns.marshal_with(build_workflow_run_model(service_api_ns)) | |||
| def get(self, app_model: App, workflow_run_id: str): | |||
| """ | |||
| Get a workflow task running detail | |||
| """Get a workflow task running detail. | |||
| Returns detailed information about a specific workflow run. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: | |||
| @@ -78,21 +123,33 @@ class WorkflowRunDetailApi(Resource): | |||
| return workflow_run | |||
| @service_api_ns.route("/workflows/run") | |||
| class WorkflowRunApi(Resource): | |||
| @service_api_ns.expect(workflow_run_parser) | |||
| @service_api_ns.doc("run_workflow") | |||
| @service_api_ns.doc(description="Execute a workflow") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Workflow executed successfully", | |||
| 400: "Bad request - invalid parameters or workflow issues", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Workflow not found", | |||
| 429: "Rate limit exceeded", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser): | |||
| """ | |||
| Run workflow | |||
| """Execute a workflow. | |||
| Runs a workflow with the provided inputs and returns the results. | |||
| Supports both blocking and streaming response modes. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| args = parser.parse_args() | |||
| args = workflow_run_parser.parse_args() | |||
| external_trace_id = get_external_trace_id(request) | |||
| if external_trace_id: | |||
| args["external_trace_id"] = external_trace_id | |||
| @@ -121,21 +178,33 @@ class WorkflowRunApi(Resource): | |||
| raise InternalServerError() | |||
| @service_api_ns.route("/workflows/<string:workflow_id>/run") | |||
| class WorkflowRunByIdApi(Resource): | |||
| @service_api_ns.expect(workflow_run_parser) | |||
| @service_api_ns.doc("run_workflow_by_id") | |||
| @service_api_ns.doc(description="Execute a specific workflow by ID") | |||
| @service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Workflow executed successfully", | |||
| 400: "Bad request - invalid parameters or workflow issues", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Workflow not found", | |||
| 429: "Rate limit exceeded", | |||
| 500: "Internal server error", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, workflow_id: str): | |||
| """ | |||
| Run specific workflow by ID | |||
| """Run specific workflow by ID. | |||
| Executes a specific workflow version identified by its ID. | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") | |||
| args = parser.parse_args() | |||
| args = workflow_run_parser.parse_args() | |||
| # Add workflow_id to args for AppGenerateService | |||
| args["workflow_id"] = workflow_id | |||
| @@ -174,12 +243,21 @@ class WorkflowRunByIdApi(Resource): | |||
| raise InternalServerError() | |||
| @service_api_ns.route("/workflows/tasks/<string:task_id>/stop") | |||
| class WorkflowTaskStopApi(Resource): | |||
| @service_api_ns.doc("stop_workflow_task") | |||
| @service_api_ns.doc(description="Stop a running workflow task") | |||
| @service_api_ns.doc(params={"task_id": "Task ID to stop"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Task stopped successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Task not found", | |||
| } | |||
| ) | |||
| @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) | |||
| def post(self, app_model: App, end_user: EndUser, task_id: str): | |||
| """ | |||
| Stop workflow task | |||
| """ | |||
| """Stop a running workflow task.""" | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| @@ -189,35 +267,25 @@ class WorkflowTaskStopApi(Resource): | |||
| return {"result": "success"} | |||
| @service_api_ns.route("/workflows/logs") | |||
| class WorkflowAppLogApi(Resource): | |||
| @service_api_ns.expect(workflow_log_parser) | |||
| @service_api_ns.doc("get_workflow_logs") | |||
| @service_api_ns.doc(description="Get workflow execution logs") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Logs retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_app_token | |||
| @marshal_with(workflow_app_log_pagination_fields) | |||
| @service_api_ns.marshal_with(build_workflow_app_log_pagination_model(service_api_ns)) | |||
| def get(self, app_model: App): | |||
| """Get workflow app logs. | |||
| Returns paginated workflow execution logs with filtering options. | |||
| """ | |||
| Get workflow app logs | |||
| """ | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("keyword", type=str, location="args") | |||
| parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") | |||
| parser.add_argument("created_at__before", type=str, location="args") | |||
| parser.add_argument("created_at__after", type=str, location="args") | |||
| parser.add_argument( | |||
| "created_by_end_user_session_id", | |||
| type=str, | |||
| location="args", | |||
| required=False, | |||
| default=None, | |||
| ) | |||
| parser.add_argument( | |||
| "created_by_account", | |||
| type=str, | |||
| location="args", | |||
| required=False, | |||
| default=None, | |||
| ) | |||
| parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") | |||
| parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") | |||
| args = parser.parse_args() | |||
| args = workflow_log_parser.parse_args() | |||
| args.status = WorkflowExecutionStatus(args.status) if args.status else None | |||
| if args.created_at__before: | |||
| @@ -243,10 +311,3 @@ class WorkflowAppLogApi(Resource): | |||
| ) | |||
| return workflow_app_log_pagination | |||
| api.add_resource(WorkflowRunApi, "/workflows/run") | |||
| api.add_resource(WorkflowRunDetailApi, "/workflows/run/<string:workflow_run_id>") | |||
| api.add_resource(WorkflowRunByIdApi, "/workflows/<string:workflow_id>/run") | |||
| api.add_resource(WorkflowTaskStopApi, "/workflows/tasks/<string:task_id>/stop") | |||
| api.add_resource(WorkflowAppLogApi, "/workflows/logs") | |||
| @@ -1,11 +1,11 @@ | |||
| from typing import Literal | |||
| from flask import request | |||
| from flask_restx import marshal, marshal_with, reqparse | |||
| from flask_restx import marshal, reqparse | |||
| from werkzeug.exceptions import Forbidden, NotFound | |||
| import services.dataset_service | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError | |||
| from controllers.service_api.wraps import ( | |||
| DatasetApiResource, | |||
| @@ -16,7 +16,7 @@ from core.model_runtime.entities.model_entities import ModelType | |||
| from core.plugin.entities.plugin import ModelProviderID | |||
| from core.provider_manager import ProviderManager | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from fields.tag_fields import tag_fields | |||
| from fields.tag_fields import build_dataset_tag_fields | |||
| from libs.login import current_user | |||
| from models.dataset import Dataset, DatasetPermissionEnum | |||
| from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService | |||
| @@ -36,12 +36,171 @@ def _validate_description_length(description): | |||
| return description | |||
| # Define parsers for dataset operations | |||
| dataset_create_parser = reqparse.RequestParser() | |||
| dataset_create_parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| dataset_create_parser.add_argument( | |||
| "description", | |||
| type=_validate_description_length, | |||
| nullable=True, | |||
| required=False, | |||
| default="", | |||
| ) | |||
| dataset_create_parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| dataset_create_parser.add_argument( | |||
| "permission", | |||
| type=str, | |||
| location="json", | |||
| choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | |||
| help="Invalid permission.", | |||
| required=False, | |||
| nullable=False, | |||
| ) | |||
| dataset_create_parser.add_argument( | |||
| "external_knowledge_api_id", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| default="_validate_name", | |||
| ) | |||
| dataset_create_parser.add_argument( | |||
| "provider", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| default="vendor", | |||
| ) | |||
| dataset_create_parser.add_argument( | |||
| "external_knowledge_id", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| ) | |||
| dataset_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") | |||
| dataset_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") | |||
| dataset_create_parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") | |||
| dataset_update_parser = reqparse.RequestParser() | |||
| dataset_update_parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| dataset_update_parser.add_argument( | |||
| "description", location="json", store_missing=False, type=_validate_description_length | |||
| ) | |||
| dataset_update_parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| dataset_update_parser.add_argument( | |||
| "permission", | |||
| type=str, | |||
| location="json", | |||
| choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | |||
| help="Invalid permission.", | |||
| ) | |||
| dataset_update_parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") | |||
| dataset_update_parser.add_argument( | |||
| "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." | |||
| ) | |||
| dataset_update_parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") | |||
| dataset_update_parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") | |||
| dataset_update_parser.add_argument( | |||
| "external_retrieval_model", | |||
| type=dict, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid external retrieval model.", | |||
| ) | |||
| dataset_update_parser.add_argument( | |||
| "external_knowledge_id", | |||
| type=str, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid external knowledge id.", | |||
| ) | |||
| dataset_update_parser.add_argument( | |||
| "external_knowledge_api_id", | |||
| type=str, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid external knowledge api id.", | |||
| ) | |||
| tag_create_parser = reqparse.RequestParser() | |||
| tag_create_parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 50 characters.", | |||
| type=lambda x: x | |||
| if x and 1 <= len(x) <= 50 | |||
| else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), | |||
| ) | |||
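The name argument above validates via a lambda that raises through a throwaway generator, a compact but opaque trick that is repeated for tag_update_parser below. An equivalent, more readable alternative is a named validator; the DatasetTagsApi class further down still defines one as _validate_tag_name:

    # Equivalent named validator (sketch); behavior matches the lambda above.
    def _validate_tag_name(name: str) -> str:
        if not name or not (1 <= len(name) <= 50):
            raise ValueError("Name must be between 1 and 50 characters.")
        return name

    # tag_create_parser.add_argument("name", required=True, nullable=False, type=_validate_tag_name, ...)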
| tag_update_parser = reqparse.RequestParser() | |||
| tag_update_parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 50 characters.", | |||
| type=lambda x: x | |||
| if x and 1 <= len(x) <= 50 | |||
| else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")), | |||
| ) | |||
| tag_update_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) | |||
| tag_delete_parser = reqparse.RequestParser() | |||
| tag_delete_parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) | |||
| tag_binding_parser = reqparse.RequestParser() | |||
| tag_binding_parser.add_argument( | |||
| "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." | |||
| ) | |||
| tag_binding_parser.add_argument( | |||
| "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." | |||
| ) | |||
| tag_unbinding_parser = reqparse.RequestParser() | |||
| tag_unbinding_parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") | |||
| tag_unbinding_parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") | |||
| @service_api_ns.route("/datasets") | |||
| class DatasetListApi(DatasetApiResource): | |||
| """Resource for datasets.""" | |||
| @service_api_ns.doc("list_datasets") | |||
| @service_api_ns.doc(description="List all datasets") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Datasets retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| def get(self, tenant_id): | |||
| """Resource for getting datasets.""" | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| # provider = request.args.get("provider", default="vendor") | |||
| @@ -76,65 +235,20 @@ class DatasetListApi(DatasetApiResource): | |||
| response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} | |||
| return response, 200 | |||
| @service_api_ns.expect(dataset_create_parser) | |||
| @service_api_ns.doc("create_dataset") | |||
| @service_api_ns.doc(description="Create a new dataset") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Dataset created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 400: "Bad request - invalid parameters", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id): | |||
| """Resource for creating datasets.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument( | |||
| "description", | |||
| type=_validate_description_length, | |||
| nullable=True, | |||
| required=False, | |||
| default="", | |||
| ) | |||
| parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| parser.add_argument( | |||
| "permission", | |||
| type=str, | |||
| location="json", | |||
| choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | |||
| help="Invalid permission.", | |||
| required=False, | |||
| nullable=False, | |||
| ) | |||
| parser.add_argument( | |||
| "external_knowledge_api_id", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| default="_validate_name", | |||
| ) | |||
| parser.add_argument( | |||
| "provider", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| default="vendor", | |||
| ) | |||
| parser.add_argument( | |||
| "external_knowledge_id", | |||
| type=str, | |||
| nullable=True, | |||
| required=False, | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| args = dataset_create_parser.parse_args() | |||
| if args.get("embedding_model_provider"): | |||
| DatasetService.check_embedding_model_setting( | |||
| @@ -174,9 +288,21 @@ class DatasetListApi(DatasetApiResource): | |||
| return marshal(dataset, dataset_detail_fields), 200 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>") | |||
| class DatasetApi(DatasetApiResource): | |||
| """Resource for dataset.""" | |||
| @service_api_ns.doc("get_dataset") | |||
| @service_api_ns.doc(description="Get a specific dataset by ID") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Dataset retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| def get(self, _, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| @@ -216,6 +342,18 @@ class DatasetApi(DatasetApiResource): | |||
| return data, 200 | |||
| @service_api_ns.expect(dataset_update_parser) | |||
| @service_api_ns.doc("update_dataset") | |||
| @service_api_ns.doc(description="Update an existing dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Dataset updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def patch(self, _, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| @@ -223,63 +361,7 @@ class DatasetApi(DatasetApiResource): | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| help="type is required. Name must be between 1 to 40 characters.", | |||
| type=_validate_name, | |||
| ) | |||
| parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) | |||
| parser.add_argument( | |||
| "indexing_technique", | |||
| type=str, | |||
| location="json", | |||
| choices=Dataset.INDEXING_TECHNIQUE_LIST, | |||
| nullable=True, | |||
| help="Invalid indexing technique.", | |||
| ) | |||
| parser.add_argument( | |||
| "permission", | |||
| type=str, | |||
| location="json", | |||
| choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | |||
| help="Invalid permission.", | |||
| ) | |||
| parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") | |||
| parser.add_argument( | |||
| "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") | |||
| parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") | |||
| parser.add_argument( | |||
| "external_retrieval_model", | |||
| type=dict, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid external retrieval model.", | |||
| ) | |||
| parser.add_argument( | |||
| "external_knowledge_id", | |||
| type=str, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid external knowledge id.", | |||
| ) | |||
| parser.add_argument( | |||
| "external_knowledge_api_id", | |||
| type=str, | |||
| required=False, | |||
| nullable=True, | |||
| location="json", | |||
| help="Invalid external knowledge api id.", | |||
| ) | |||
| args = parser.parse_args() | |||
| args = dataset_update_parser.parse_args() | |||
| data = request.get_json() | |||
| # check embedding model setting | |||
| @@ -327,6 +409,17 @@ class DatasetApi(DatasetApiResource): | |||
| return result_data, 200 | |||
| @service_api_ns.doc("delete_dataset") | |||
| @service_api_ns.doc(description="Delete a dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Dataset deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| 409: "Conflict - dataset is in use", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def delete(self, _, dataset_id): | |||
| """ | |||
| @@ -357,9 +450,27 @@ class DatasetApi(DatasetApiResource): | |||
| raise DatasetInUseError() | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>") | |||
| class DocumentStatusApi(DatasetApiResource): | |||
| """Resource for batch document status operations.""" | |||
| @service_api_ns.doc("update_document_status") | |||
| @service_api_ns.doc(description="Batch update document status") | |||
| @service_api_ns.doc( | |||
| params={ | |||
| "dataset_id": "Dataset ID", | |||
| "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'", | |||
| } | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Document status updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| 404: "Dataset not found", | |||
| 400: "Bad request - invalid action", | |||
| } | |||
| ) | |||
| def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]): | |||
| """ | |||
| Batch update document status. | |||
| @@ -407,53 +518,65 @@ class DocumentStatusApi(DatasetApiResource): | |||
| return {"result": "success"}, 200 | |||
| @service_api_ns.route("/datasets/tags") | |||
| class DatasetTagsApi(DatasetApiResource): | |||
| @service_api_ns.doc("list_dataset_tags") | |||
| @service_api_ns.doc(description="Get all knowledge type tags") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Tags retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_dataset_token | |||
| @marshal_with(tag_fields) | |||
| @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) | |||
| def get(self, _, dataset_id): | |||
| """Get all knowledge type tags.""" | |||
| tags = TagService.get_tags("knowledge", current_user.current_tenant_id) | |||
| return tags, 200 | |||
| @service_api_ns.expect(tag_create_parser) | |||
| @service_api_ns.doc("create_dataset_tag") | |||
| @service_api_ns.doc(description="Add a knowledge type tag") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Tag created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| } | |||
| ) | |||
| @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| """Add a knowledge type tag.""" | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 50 characters.", | |||
| type=DatasetTagsApi._validate_tag_name, | |||
| ) | |||
| args = parser.parse_args() | |||
| args = tag_create_parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| tag = TagService.save_tags(args) | |||
| response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} | |||
| return response, 200 | |||
| @service_api_ns.expect(tag_update_parser) | |||
| @service_api_ns.doc("update_dataset_tag") | |||
| @service_api_ns.doc(description="Update a knowledge type tag") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Tag updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| } | |||
| ) | |||
| @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) | |||
| @validate_dataset_token | |||
| def patch(self, _, dataset_id): | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "name", | |||
| nullable=False, | |||
| required=True, | |||
| help="Name must be between 1 to 50 characters.", | |||
| type=DatasetTagsApi._validate_tag_name, | |||
| ) | |||
| parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) | |||
| args = parser.parse_args() | |||
| args = tag_update_parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| tag = TagService.update_tags(args, args.get("tag_id")) | |||
| @@ -463,66 +586,88 @@ class DatasetTagsApi(DatasetApiResource): | |||
| return response, 200 | |||
| @service_api_ns.expect(tag_delete_parser) | |||
| @service_api_ns.doc("delete_dataset_tag") | |||
| @service_api_ns.doc(description="Delete a knowledge type tag") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Tag deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| } | |||
| ) | |||
| @validate_dataset_token | |||
| def delete(self, _, dataset_id): | |||
| """Delete a knowledge type tag.""" | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) | |||
| args = parser.parse_args() | |||
| args = tag_delete_parser.parse_args() | |||
| TagService.delete_tag(args.get("tag_id")) | |||
| return 204 | |||
| @staticmethod | |||
| def _validate_tag_name(name): | |||
| if not name or len(name) < 1 or len(name) > 50: | |||
| raise ValueError("Name must be between 1 to 50 characters.") | |||
| return name | |||
| @service_api_ns.route("/datasets/tags/binding") | |||
| class DatasetTagBindingApi(DatasetApiResource): | |||
| @service_api_ns.expect(tag_binding_parser) | |||
| @service_api_ns.doc("bind_dataset_tags") | |||
| @service_api_ns.doc(description="Bind tags to a dataset") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Tags bound successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| } | |||
| ) | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument( | |||
| "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." | |||
| ) | |||
| parser.add_argument( | |||
| "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." | |||
| ) | |||
| args = parser.parse_args() | |||
| args = tag_binding_parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| TagService.save_tag_binding(args) | |||
| return 204 | |||
| @service_api_ns.route("/datasets/tags/unbinding") | |||
| class DatasetTagUnbindingApi(DatasetApiResource): | |||
| @service_api_ns.expect(tag_unbinding_parser) | |||
| @service_api_ns.doc("unbind_dataset_tag") | |||
| @service_api_ns.doc(description="Unbind a tag from a dataset") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Tag unbound successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| } | |||
| ) | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") | |||
| parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") | |||
| args = parser.parse_args() | |||
| args = tag_unbinding_parser.parse_args() | |||
| args["type"] = "knowledge" | |||
| TagService.delete_tag_binding(args) | |||
| return 204 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/tags") | |||
| class DatasetTagsBindingStatusApi(DatasetApiResource): | |||
| @service_api_ns.doc("get_dataset_tags_binding_status") | |||
| @service_api_ns.doc(description="Get tags bound to a specific dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Tags retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_dataset_token | |||
| def get(self, _, *args, **kwargs): | |||
| """Get all knowledge type tags.""" | |||
| @@ -531,12 +676,3 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): | |||
| tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] | |||
| response = {"data": tags_list, "total": len(tags)} | |||
| return response, 200 | |||
| api.add_resource(DatasetListApi, "/datasets") | |||
| api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | |||
| api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>") | |||
| api.add_resource(DatasetTagsApi, "/datasets/tags") | |||
| api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") | |||
| api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") | |||
| api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags") | |||
| @@ -13,7 +13,7 @@ from controllers.common.errors import ( | |||
| TooManyFilesError, | |||
| UnsupportedFileTypeError, | |||
| ) | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import ProviderNotInitializeError | |||
| from controllers.service_api.dataset.error import ( | |||
| ArchivedDocumentImmutableError, | |||
| @@ -34,32 +34,64 @@ from services.dataset_service import DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig | |||
| from services.file_service import FileService | |||
| # Define parsers for document operations | |||
| document_text_create_parser = reqparse.RequestParser() | |||
| document_text_create_parser.add_argument("name", type=str, required=True, nullable=False, location="json") | |||
| document_text_create_parser.add_argument("text", type=str, required=True, nullable=False, location="json") | |||
| document_text_create_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") | |||
| document_text_create_parser.add_argument("original_document_id", type=str, required=False, location="json") | |||
| document_text_create_parser.add_argument( | |||
| "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" | |||
| ) | |||
| document_text_create_parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| document_text_create_parser.add_argument( | |||
| "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" | |||
| ) | |||
| document_text_create_parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") | |||
| document_text_create_parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") | |||
| document_text_create_parser.add_argument( | |||
| "embedding_model_provider", type=str, required=False, nullable=True, location="json" | |||
| ) | |||
| document_text_update_parser = reqparse.RequestParser() | |||
| document_text_update_parser.add_argument("name", type=str, required=False, nullable=True, location="json") | |||
| document_text_update_parser.add_argument("text", type=str, required=False, nullable=True, location="json") | |||
| document_text_update_parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") | |||
| document_text_update_parser.add_argument( | |||
| "doc_form", type=str, default="text_model", required=False, nullable=False, location="json" | |||
| ) | |||
| document_text_update_parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| document_text_update_parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| @service_api_ns.route( | |||
| "/datasets/<uuid:dataset_id>/document/create_by_text", | |||
| "/datasets/<uuid:dataset_id>/document/create-by-text", | |||
| ) | |||
| class DocumentAddByTextApi(DatasetApiResource): | |||
| """Resource for documents.""" | |||
| @service_api_ns.expect(document_text_create_parser) | |||
| @service_api_ns.doc("create_document_by_text") | |||
| @service_api_ns.doc(description="Create a new document by providing text content") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Document created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 400: "Bad request - invalid parameters", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_resource_check("documents", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id): | |||
| """Create document by text.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("text", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("original_document_id", type=str, required=False, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument( | |||
| "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| args = document_text_create_parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| @@ -117,23 +149,29 @@ class DocumentAddByTextApi(DatasetApiResource): | |||
| return documents_and_batch_fields, 200 | |||
| @service_api_ns.route( | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text", | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text", | |||
| ) | |||
| class DocumentUpdateByTextApi(DatasetApiResource): | |||
| """Resource for update documents.""" | |||
| @service_api_ns.expect(document_text_update_parser) | |||
| @service_api_ns.doc("update_document_by_text") | |||
| @service_api_ns.doc(description="Update an existing document by providing text content") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Document updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Document not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| """Update document by text.""" | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("text", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| args = document_text_update_parser.parse_args() | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| @@ -187,9 +225,23 @@ class DocumentUpdateByTextApi(DatasetApiResource): | |||
| return documents_and_batch_fields, 200 | |||
| @service_api_ns.route( | |||
| "/datasets/<uuid:dataset_id>/document/create_by_file", | |||
| "/datasets/<uuid:dataset_id>/document/create-by-file", | |||
| ) | |||
| class DocumentAddByFileApi(DatasetApiResource): | |||
| """Resource for documents.""" | |||
| @service_api_ns.doc("create_document_by_file") | |||
| @service_api_ns.doc(description="Create a new document by uploading a file") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Document created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 400: "Bad request - invalid file or parameters", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_resource_check("documents", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| @@ -281,9 +333,23 @@ class DocumentAddByFileApi(DatasetApiResource): | |||
| return documents_and_batch_fields, 200 | |||
| @service_api_ns.route( | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file", | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file", | |||
| ) | |||
| class DocumentUpdateByFileApi(DatasetApiResource): | |||
| """Resource for update documents.""" | |||
| @service_api_ns.doc("update_document_by_file") | |||
| @service_api_ns.doc(description="Update an existing document by uploading a file") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Document updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Document not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| @@ -358,7 +424,18 @@ class DocumentUpdateByFileApi(DatasetApiResource): | |||
| return documents_and_batch_fields, 200 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents") | |||
| class DocumentListApi(DatasetApiResource): | |||
| @service_api_ns.doc("list_documents") | |||
| @service_api_ns.doc(description="List all documents in a dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Documents retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| @@ -391,7 +468,18 @@ class DocumentListApi(DatasetApiResource): | |||
| return response | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status") | |||
| class DocumentIndexingStatusApi(DatasetApiResource): | |||
| @service_api_ns.doc("get_document_indexing_status") | |||
| @service_api_ns.doc(description="Get indexing status for documents in a batch") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "batch": "Batch ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Indexing status retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset or documents not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id, dataset_id, batch): | |||
| dataset_id = str(dataset_id) | |||
| batch = str(batch) | |||
| @@ -440,9 +528,21 @@ class DocumentIndexingStatusApi(DatasetApiResource): | |||
| return data | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") | |||
| class DocumentApi(DatasetApiResource): | |||
| METADATA_CHOICES = {"all", "only", "without"} | |||
| @service_api_ns.doc("get_document") | |||
| @service_api_ns.doc(description="Get a specific document by ID") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Document retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - insufficient permissions", | |||
| 404: "Document not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id, dataset_id, document_id): | |||
| dataset_id = str(dataset_id) | |||
| document_id = str(document_id) | |||
| @@ -534,6 +634,17 @@ class DocumentApi(DatasetApiResource): | |||
| return response | |||
| @service_api_ns.doc("delete_document") | |||
| @service_api_ns.doc(description="Delete a document") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Document deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 403: "Forbidden - document is archived", | |||
| 404: "Document not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def delete(self, tenant_id, dataset_id, document_id): | |||
| """Delete document.""" | |||
| @@ -564,28 +675,3 @@ class DocumentApi(DatasetApiResource): | |||
| raise DocumentIndexingError("Cannot delete document during indexing.") | |||
| return 204 | |||
| api.add_resource( | |||
| DocumentAddByTextApi, | |||
| "/datasets/<uuid:dataset_id>/document/create_by_text", | |||
| "/datasets/<uuid:dataset_id>/document/create-by-text", | |||
| ) | |||
| api.add_resource( | |||
| DocumentAddByFileApi, | |||
| "/datasets/<uuid:dataset_id>/document/create_by_file", | |||
| "/datasets/<uuid:dataset_id>/document/create-by-file", | |||
| ) | |||
| api.add_resource( | |||
| DocumentUpdateByTextApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text", | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text", | |||
| ) | |||
| api.add_resource( | |||
| DocumentUpdateByFileApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file", | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file", | |||
| ) | |||
| api.add_resource(DocumentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") | |||
| api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents") | |||
| api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status") | |||
| @@ -1,11 +1,26 @@ | |||
| from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve") | |||
| class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): | |||
| @service_api_ns.doc("dataset_hit_testing") | |||
| @service_api_ns.doc(description="Perform hit testing on a dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Hit testing results", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id): | |||
| """Perform hit testing on a dataset. | |||
| Tests retrieval performance for the specified dataset. | |||
| """ | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = self.get_and_validate_dataset(dataset_id_str) | |||
| @@ -13,6 +28,3 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): | |||
| self.hit_testing_args_check(args) | |||
| return self.perform_hit_testing(dataset, args) | |||
| api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve") | |||
| @@ -4,7 +4,7 @@ from flask_login import current_user # type: ignore | |||
| from flask_restx import marshal, reqparse | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check | |||
| from fields.dataset_fields import dataset_metadata_fields | |||
| from services.dataset_service import DatasetService | |||
| @@ -14,14 +14,43 @@ from services.entities.knowledge_entities.knowledge_entities import ( | |||
| ) | |||
| from services.metadata_service import MetadataService | |||
| # Define parsers for metadata APIs | |||
| metadata_create_parser = reqparse.RequestParser() | |||
| metadata_create_parser.add_argument( | |||
| "type", type=str, required=True, nullable=False, location="json", help="Metadata type" | |||
| ) | |||
| metadata_create_parser.add_argument( | |||
| "name", type=str, required=True, nullable=False, location="json", help="Metadata name" | |||
| ) | |||
| metadata_update_parser = reqparse.RequestParser() | |||
| metadata_update_parser.add_argument( | |||
| "name", type=str, required=True, nullable=False, location="json", help="New metadata name" | |||
| ) | |||
| document_metadata_parser = reqparse.RequestParser() | |||
| document_metadata_parser.add_argument( | |||
| "operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data" | |||
| ) | |||
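Illustrative JSON bodies for the three parsers above; the field names come from the parsers, the values are examples only:

    create_payload = {"type": "string", "name": "author"}   # metadata_create_parser
    rename_payload = {"name": "writer"}                      # metadata_update_parser
    # document_metadata_parser only requires a top-level "operation_data" list; the element
    # schema is validated by the service layer and is not shown in this diff.
    update_documents_payload = {"operation_data": []}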
| @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata") | |||
| class DatasetMetadataCreateServiceApi(DatasetApiResource): | |||
| @service_api_ns.expect(metadata_create_parser) | |||
| @service_api_ns.doc("create_dataset_metadata") | |||
| @service_api_ns.doc(description="Create metadata for a dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 201: "Metadata created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("type", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| """Create metadata for a dataset.""" | |||
| args = metadata_create_parser.parse_args() | |||
| metadata_args = MetadataArgs(**args) | |||
| dataset_id_str = str(dataset_id) | |||
| @@ -33,7 +62,18 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): | |||
| metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) | |||
| return marshal(metadata, dataset_metadata_fields), 201 | |||
| @service_api_ns.doc("get_dataset_metadata") | |||
| @service_api_ns.doc(description="Get all metadata for a dataset") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Metadata retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id, dataset_id): | |||
| """Get all metadata for a dataset.""" | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| @@ -41,12 +81,23 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): | |||
| return MetadataService.get_dataset_metadatas(dataset), 200 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") | |||
| class DatasetMetadataServiceApi(DatasetApiResource): | |||
| @service_api_ns.expect(metadata_update_parser) | |||
| @service_api_ns.doc("update_dataset_metadata") | |||
| @service_api_ns.doc(description="Update metadata name") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Metadata updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset or metadata not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def patch(self, tenant_id, dataset_id, metadata_id): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| """Update metadata name.""" | |||
| args = metadata_update_parser.parse_args() | |||
| dataset_id_str = str(dataset_id) | |||
| metadata_id_str = str(metadata_id) | |||
| @@ -58,8 +109,19 @@ class DatasetMetadataServiceApi(DatasetApiResource): | |||
| metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) | |||
| return marshal(metadata, dataset_metadata_fields), 200 | |||
| @service_api_ns.doc("delete_dataset_metadata") | |||
| @service_api_ns.doc(description="Delete metadata") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Metadata deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset or metadata not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def delete(self, tenant_id, dataset_id, metadata_id): | |||
| """Delete metadata.""" | |||
| dataset_id_str = str(dataset_id) | |||
| metadata_id_str = str(metadata_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| @@ -71,15 +133,37 @@ class DatasetMetadataServiceApi(DatasetApiResource): | |||
| return 204 | |||
| @service_api_ns.route("/datasets/metadata/built-in") | |||
| class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): | |||
| @service_api_ns.doc("get_built_in_fields") | |||
| @service_api_ns.doc(description="Get all built-in metadata fields") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Built-in fields retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| def get(self, tenant_id): | |||
| """Get all built-in metadata fields.""" | |||
| built_in_fields = MetadataService.get_built_in_fields() | |||
| return {"fields": built_in_fields}, 200 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>") | |||
| class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): | |||
| @service_api_ns.doc("toggle_built_in_field") | |||
| @service_api_ns.doc(description="Enable or disable built-in metadata field") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "action": "Action to perform: 'enable' or 'disable'"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Action completed successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id, action: Literal["enable", "disable"]): | |||
| """Enable or disable built-in metadata field.""" | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| @@ -93,29 +177,31 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): | |||
| return 200 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata") | |||
| class DocumentMetadataEditServiceApi(DatasetApiResource): | |||
| @service_api_ns.expect(document_metadata_parser) | |||
| @service_api_ns.doc("update_documents_metadata") | |||
| @service_api_ns.doc(description="Update metadata for multiple documents") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Documents metadata updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id): | |||
| """Update metadata for multiple documents.""" | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| args = document_metadata_parser.parse_args() | |||
| metadata_args = MetadataOperationData(**args) | |||
| MetadataService.update_documents_metadata(dataset, metadata_args) | |||
| return 200 | |||
| api.add_resource(DatasetMetadataCreateServiceApi, "/datasets/<uuid:dataset_id>/metadata") | |||
| api.add_resource(DatasetMetadataServiceApi, "/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>") | |||
| api.add_resource(DatasetMetadataBuiltInFieldServiceApi, "/datasets/metadata/built-in") | |||
| api.add_resource( | |||
| DatasetMetadataBuiltInFieldActionServiceApi, "/datasets/<uuid:dataset_id>/metadata/built-in/<string:action>" | |||
| ) | |||
| api.add_resource(DocumentMetadataEditServiceApi, "/datasets/<uuid:dataset_id>/documents/metadata") | |||
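With the routes and parsers above, the metadata endpoints can be exercised directly over HTTP. A hedged client-side sketch using requests: the base URL, token, and dataset ID are placeholders, and only the /datasets/<id>/metadata path and the type/name payload come from this diff.

import requests

# Placeholder values for illustration only
BASE = "https://api.example.com/v1"
DATASET_ID = "00000000-0000-0000-0000-000000000000"
headers = {"Authorization": "Bearer <dataset-api-token>", "Content-Type": "application/json"}

# Create a metadata field on the dataset (documented to return 201 on success);
# "string" is an assumed example value for the type field
resp = requests.post(
    f"{BASE}/datasets/{DATASET_ID}/metadata",
    headers=headers,
    json={"type": "string", "name": "author"},
)
print(resp.status_code, resp.json())

# List all metadata for the dataset
resp = requests.get(f"{BASE}/datasets/{DATASET_ID}/metadata", headers=headers)
print(resp.json())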
| @@ -3,7 +3,7 @@ from flask_login import current_user | |||
| from flask_restx import marshal, reqparse | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.app.error import ProviderNotInitializeError | |||
| from controllers.service_api.wraps import ( | |||
| DatasetApiResource, | |||
| @@ -19,34 +19,59 @@ from fields.segment_fields import child_chunk_fields, segment_fields | |||
| from models.dataset import Dataset | |||
| from services.dataset_service import DatasetService, DocumentService, SegmentService | |||
| from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs | |||
| from services.errors.chunk import ( | |||
| ChildChunkDeleteIndexError, | |||
| ChildChunkIndexingError, | |||
| ) | |||
| from services.errors.chunk import ( | |||
| ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError, | |||
| ) | |||
| from services.errors.chunk import ( | |||
| ChildChunkIndexingError as ChildChunkIndexingServiceError, | |||
| ) | |||
| from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError | |||
| from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError | |||
| from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError | |||
| # Define parsers for segment operations | |||
| segment_create_parser = reqparse.RequestParser() | |||
| segment_create_parser.add_argument("segments", type=list, required=False, nullable=True, location="json") | |||
| segment_list_parser = reqparse.RequestParser() | |||
| segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args") | |||
| segment_list_parser.add_argument("keyword", type=str, default=None, location="args") | |||
| segment_update_parser = reqparse.RequestParser() | |||
| segment_update_parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") | |||
| child_chunk_create_parser = reqparse.RequestParser() | |||
| child_chunk_create_parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| child_chunk_list_parser = reqparse.RequestParser() | |||
| child_chunk_list_parser.add_argument("limit", type=int, default=20, location="args") | |||
| child_chunk_list_parser.add_argument("keyword", type=str, default=None, location="args") | |||
| child_chunk_list_parser.add_argument("page", type=int, default=1, location="args") | |||
| child_chunk_update_parser = reqparse.RequestParser() | |||
| child_chunk_update_parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
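The list parsers above read from the query string (location="args") rather than the JSON body, and action="append" collects repeated parameters into a list. A small sketch of that behavior, runnable inside a throwaway Flask request context; the URL below is illustrative:

from flask import Flask
from flask_restx import reqparse

app = Flask(__name__)

segment_list_parser = reqparse.RequestParser()
segment_list_parser.add_argument("status", type=str, action="append", default=[], location="args")
segment_list_parser.add_argument("keyword", type=str, default=None, location="args")

with app.test_request_context("/segments?status=completed&status=error&keyword=tax"):
    args = segment_list_parser.parse_args()
    print(args["status"])   # ['completed', 'error'] - repeated params are appended
    print(args["keyword"])  # 'tax'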
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") | |||
| class SegmentApi(DatasetApiResource): | |||
| """Resource for segments.""" | |||
| @service_api_ns.expect(segment_create_parser) | |||
| @service_api_ns.doc("create_segments") | |||
| @service_api_ns.doc(description="Create segments in a document") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Segments created successfully", | |||
| 400: "Bad request - segments data is missing", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset or document not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id): | |||
| def post(self, tenant_id: str, dataset_id: str, document_id: str): | |||
| """Create single segment.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| @@ -71,9 +96,7 @@ class SegmentApi(DatasetApiResource): | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("segments", type=list, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| args = segment_create_parser.parse_args() | |||
| if args["segments"] is not None: | |||
| for args_item in args["segments"]: | |||
| SegmentService.segment_create_args_validate(args_item, document) | |||
| @@ -82,18 +105,26 @@ class SegmentApi(DatasetApiResource): | |||
| else: | |||
| return {"error": "Segments is required"}, 400 | |||
| def get(self, tenant_id, dataset_id, document_id): | |||
| @service_api_ns.expect(segment_list_parser) | |||
| @service_api_ns.doc("list_segments") | |||
| @service_api_ns.doc(description="List segments in a document") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Segments retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset or document not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id: str, dataset_id: str, document_id: str): | |||
| """Get segments.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| page = request.args.get("page", default=1, type=int) | |||
| limit = request.args.get("limit", default=20, type=int) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| @@ -114,10 +145,7 @@ class SegmentApi(DatasetApiResource): | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("status", type=str, action="append", default=[], location="args") | |||
| parser.add_argument("keyword", type=str, default=None, location="args") | |||
| args = parser.parse_args() | |||
| args = segment_list_parser.parse_args() | |||
| segments, total = SegmentService.get_segments( | |||
| document_id=document_id, | |||
| @@ -140,43 +168,62 @@ class SegmentApi(DatasetApiResource): | |||
| return response, 200 | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>") | |||
| class DatasetSegmentApi(DatasetApiResource): | |||
| @service_api_ns.doc("delete_segment") | |||
| @service_api_ns.doc(description="Delete a specific segment") | |||
| @service_api_ns.doc( | |||
| params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to delete"} | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Segment deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, or segment not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def delete(self, tenant_id, dataset_id, document_id, segment_id): | |||
| def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| SegmentService.delete_segment(segment, document, dataset) | |||
| return 204 | |||
| @service_api_ns.expect(segment_update_parser) | |||
| @service_api_ns.doc("update_segment") | |||
| @service_api_ns.doc(description="Update a specific segment") | |||
| @service_api_ns.doc( | |||
| params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Segment ID to update"} | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Segment updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, or segment not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id, segment_id): | |||
| def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| @@ -197,37 +244,39 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| args = segment_update_parser.parse_args() | |||
| updated_segment = SegmentService.update_segment( | |||
| SegmentUpdateArgs(**args["segment"]), segment, document, dataset | |||
| ) | |||
| return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| def get(self, tenant_id, dataset_id, document_id, segment_id): | |||
| @service_api_ns.doc("get_segment") | |||
| @service_api_ns.doc(description="Get a specific segment by ID") | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Segment retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, or segment not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| @@ -235,29 +284,41 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| @service_api_ns.route( | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks" | |||
| ) | |||
| class ChildChunkApi(DatasetApiResource): | |||
| """Resource for child chunks.""" | |||
| @service_api_ns.expect(child_chunk_create_parser) | |||
| @service_api_ns.doc("create_child_chunk") | |||
| @service_api_ns.doc(description="Create a new child chunk for a segment") | |||
| @service_api_ns.doc( | |||
| params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Child chunk created successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, or segment not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def post(self, tenant_id, dataset_id, document_id, segment_id): | |||
| def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): | |||
| """Create child chunk.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| @@ -280,43 +341,46 @@ class ChildChunkApi(DatasetApiResource): | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| args = child_chunk_create_parser.parse_args() | |||
| try: | |||
| child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) | |||
| child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset) | |||
| except ChildChunkIndexingServiceError as e: | |||
| raise ChildChunkIndexingError(str(e)) | |||
| return {"data": marshal(child_chunk, child_chunk_fields)}, 200 | |||
| def get(self, tenant_id, dataset_id, document_id, segment_id): | |||
| @service_api_ns.expect(child_chunk_list_parser) | |||
| @service_api_ns.doc("list_child_chunks") | |||
| @service_api_ns.doc(description="List child chunks for a segment") | |||
| @service_api_ns.doc( | |||
| params={"dataset_id": "Dataset ID", "document_id": "Document ID", "segment_id": "Parent segment ID"} | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Child chunks retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, or segment not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): | |||
| """Get child chunks.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("limit", type=int, default=20, location="args") | |||
| parser.add_argument("keyword", type=str, default=None, location="args") | |||
| parser.add_argument("page", type=int, default=1, location="args") | |||
| args = parser.parse_args() | |||
| args = child_chunk_list_parser.parse_args() | |||
| page = args["page"] | |||
| limit = min(args["limit"], 100) | |||
| @@ -333,28 +397,44 @@ class ChildChunkApi(DatasetApiResource): | |||
| }, 200 | |||
| @service_api_ns.route( | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>" | |||
| ) | |||
| class DatasetChildChunkApi(DatasetApiResource): | |||
| """Resource for updating child chunks.""" | |||
| @service_api_ns.doc("delete_child_chunk") | |||
| @service_api_ns.doc(description="Delete a specific child chunk") | |||
| @service_api_ns.doc( | |||
| params={ | |||
| "dataset_id": "Dataset ID", | |||
| "document_id": "Document ID", | |||
| "segment_id": "Parent segment ID", | |||
| "child_chunk_id": "Child chunk ID to delete", | |||
| } | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 204: "Child chunk deleted successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, segment, or child chunk not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): | |||
| def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): | |||
| """Delete child chunk.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| @@ -364,7 +444,6 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| raise NotFound("Document not found.") | |||
| # check child chunk | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = SegmentService.get_child_chunk_by_id( | |||
| child_chunk_id=child_chunk_id, tenant_id=current_user.current_tenant_id | |||
| ) | |||
| @@ -382,14 +461,30 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| return 204 | |||
| @service_api_ns.expect(child_chunk_update_parser) | |||
| @service_api_ns.doc("update_child_chunk") | |||
| @service_api_ns.doc(description="Update a specific child chunk") | |||
| @service_api_ns.doc( | |||
| params={ | |||
| "dataset_id": "Dataset ID", | |||
| "document_id": "Document ID", | |||
| "segment_id": "Parent segment ID", | |||
| "child_chunk_id": "Child chunk ID to update", | |||
| } | |||
| ) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Child chunk updated successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, segment, or child chunk not found", | |||
| } | |||
| ) | |||
| @cloud_edition_billing_resource_check("vector_space", "dataset") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") | |||
| @cloud_edition_billing_rate_limit_check("knowledge", "dataset") | |||
| def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): | |||
| def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str): | |||
| """Update child chunk.""" | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -420,28 +515,11 @@ class DatasetChildChunkApi(DatasetApiResource): | |||
| raise NotFound("Child chunk not found.") | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| args = child_chunk_update_parser.parse_args() | |||
| try: | |||
| child_chunk = SegmentService.update_child_chunk( | |||
| args.get("content"), child_chunk, segment, document, dataset | |||
| ) | |||
| child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset) | |||
| except ChildChunkIndexingServiceError as e: | |||
| raise ChildChunkIndexingError(str(e)) | |||
| return {"data": marshal(child_chunk, child_chunk_fields)}, 200 | |||
| api.add_resource(SegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") | |||
| api.add_resource( | |||
| DatasetSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>" | |||
| ) | |||
| api.add_resource( | |||
| ChildChunkApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks" | |||
| ) | |||
| api.add_resource( | |||
| DatasetChildChunkApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>", | |||
| ) | |||
| @@ -1,6 +1,6 @@ | |||
| from werkzeug.exceptions import NotFound | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import ( | |||
| DatasetApiResource, | |||
| ) | |||
| @@ -11,9 +11,23 @@ from models.model import UploadFile | |||
| from services.dataset_service import DocumentService | |||
| @service_api_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file") | |||
| class UploadFileApi(DatasetApiResource): | |||
| @service_api_ns.doc("get_upload_file") | |||
| @service_api_ns.doc(description="Get upload file information and download URL") | |||
| @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Upload file information retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| 404: "Dataset, document, or upload file not found", | |||
| } | |||
| ) | |||
| def get(self, tenant_id, dataset_id, document_id): | |||
| """Get upload file.""" | |||
| """Get upload file information and download URL. | |||
| Returns information about an uploaded file including its download URL. | |||
| """ | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| tenant_id = str(tenant_id) | |||
| @@ -49,6 +63,3 @@ class UploadFileApi(DatasetApiResource): | |||
| "created_by": upload_file.created_by, | |||
| "created_at": upload_file.created_at.timestamp(), | |||
| }, 200 | |||
| api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file") | |||
| @@ -1,9 +1,10 @@ | |||
| from flask_restx import Resource | |||
| from configs import dify_config | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| @service_api_ns.route("/") | |||
| class IndexApi(Resource): | |||
| def get(self): | |||
| return { | |||
| @@ -11,6 +12,3 @@ class IndexApi(Resource): | |||
| "api_version": "v1", | |||
| "server_version": dify_config.project.version, | |||
| } | |||
| api.add_resource(IndexApi, "/") | |||
| @@ -1,21 +1,32 @@ | |||
| from flask_login import current_user | |||
| from flask_restx import Resource | |||
| from controllers.service_api import api | |||
| from controllers.service_api import service_api_ns | |||
| from controllers.service_api.wraps import validate_dataset_token | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from services.model_provider_service import ModelProviderService | |||
| @service_api_ns.route("/workspaces/current/models/model-types/<string:model_type>") | |||
| class ModelProviderAvailableModelApi(Resource): | |||
| @service_api_ns.doc("get_available_models") | |||
| @service_api_ns.doc(description="Get available models by model type") | |||
| @service_api_ns.doc(params={"model_type": "Type of model to retrieve"}) | |||
| @service_api_ns.doc( | |||
| responses={ | |||
| 200: "Models retrieved successfully", | |||
| 401: "Unauthorized - invalid API token", | |||
| } | |||
| ) | |||
| @validate_dataset_token | |||
| def get(self, _, model_type): | |||
| """Get available models by model type. | |||
| Returns a list of available models for the specified model type. | |||
| """ | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) | |||
| return jsonable_encoder({"data": models}) | |||
| api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") | |||
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from libs.helper import TimestampField | |||
| @@ -11,6 +11,12 @@ annotation_fields = { | |||
| # 'account': fields.Nested(simple_account_fields, allow_null=True) | |||
| } | |||
| def build_annotation_model(api_or_ns: Api | Namespace): | |||
| """Build the annotation model for the API or Namespace.""" | |||
| return api_or_ns.model("Annotation", annotation_fields) | |||
| annotation_list_fields = { | |||
| "data": fields.List(fields.Nested(annotation_fields)), | |||
| } | |||
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from fields.member_fields import simple_account_fields | |||
| from libs.helper import TimestampField | |||
| @@ -45,6 +45,12 @@ message_file_fields = { | |||
| "upload_file_id": fields.String(default=None), | |||
| } | |||
| def build_message_file_model(api_or_ns: Api | Namespace): | |||
| """Build the message file fields for the API or Namespace.""" | |||
| return api_or_ns.model("MessageFile", message_file_fields) | |||
| agent_thought_fields = { | |||
| "id": fields.String, | |||
| "chain_id": fields.String, | |||
| @@ -209,3 +215,22 @@ conversation_infinite_scroll_pagination_fields = { | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(simple_conversation_fields)), | |||
| } | |||
| def build_conversation_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): | |||
| """Build the conversation infinite scroll pagination model for the API or Namespace.""" | |||
| simple_conversation_model = build_simple_conversation_model(api_or_ns) | |||
| copied_fields = conversation_infinite_scroll_pagination_fields.copy() | |||
| copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model)) | |||
| return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields) | |||
| def build_conversation_delete_model(api_or_ns: Api | Namespace): | |||
| """Build the conversation delete model for the API or Namespace.""" | |||
| return api_or_ns.model("ConversationDelete", conversation_delete_fields) | |||
| def build_simple_conversation_model(api_or_ns: Api | Namespace): | |||
| """Build the simple conversation model for the API or Namespace.""" | |||
| return api_or_ns.model("SimpleConversation", simple_conversation_fields) | |||
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from libs.helper import TimestampField | |||
| @@ -27,3 +27,19 @@ conversation_variable_infinite_scroll_pagination_fields = { | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(conversation_variable_fields)), | |||
| } | |||
| def build_conversation_variable_model(api_or_ns: Api | Namespace): | |||
| """Build the conversation variable model for the API or Namespace.""" | |||
| return api_or_ns.model("ConversationVariable", conversation_variable_fields) | |||
| def build_conversation_variable_infinite_scroll_pagination_model(api_or_ns: Api | Namespace): | |||
| """Build the conversation variable infinite scroll pagination model for the API or Namespace.""" | |||
| # Build the nested variable model first | |||
| conversation_variable_model = build_conversation_variable_model(api_or_ns) | |||
| copied_fields = conversation_variable_infinite_scroll_pagination_fields.copy() | |||
| copied_fields["data"] = fields.List(fields.Nested(conversation_variable_model)) | |||
| return api_or_ns.model("ConversationVariableInfiniteScrollPagination", copied_fields) | |||
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| simple_end_user_fields = { | |||
| "id": fields.String, | |||
| @@ -6,3 +6,7 @@ simple_end_user_fields = { | |||
| "is_anonymous": fields.Boolean, | |||
| "session_id": fields.String, | |||
| } | |||
| def build_simple_end_user_model(api_or_ns: Api | Namespace): | |||
| return api_or_ns.model("SimpleEndUser", simple_end_user_fields) | |||
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from libs.helper import TimestampField | |||
| @@ -11,6 +11,19 @@ upload_config_fields = { | |||
| "workflow_file_upload_limit": fields.Integer, | |||
| } | |||
| def build_upload_config_model(api_or_ns: Api | Namespace): | |||
| """Build the upload config model for the API or Namespace. | |||
| Args: | |||
| api_or_ns: Flask-RestX Api or Namespace instance | |||
| Returns: | |||
| The registered model | |||
| """ | |||
| return api_or_ns.model("UploadConfig", upload_config_fields) | |||
| file_fields = { | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| @@ -22,12 +35,37 @@ file_fields = { | |||
| "preview_url": fields.String, | |||
| } | |||
| def build_file_model(api_or_ns: Api | Namespace): | |||
| """Build the file model for the API or Namespace. | |||
| Args: | |||
| api_or_ns: Flask-RestX Api or Namespace instance | |||
| Returns: | |||
| The registered model | |||
| """ | |||
| return api_or_ns.model("File", file_fields) | |||
| remote_file_info_fields = { | |||
| "file_type": fields.String(attribute="file_type"), | |||
| "file_length": fields.Integer(attribute="file_length"), | |||
| } | |||
| def build_remote_file_info_model(api_or_ns: Api | Namespace): | |||
| """Build the remote file info model for the API or Namespace. | |||
| Args: | |||
| api_or_ns: Flask-RestX Api or Namespace instance | |||
| Returns: | |||
| The registered model | |||
| """ | |||
| return api_or_ns.model("RemoteFileInfo", remote_file_info_fields) | |||
| file_fields_with_signed_url = { | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| @@ -38,3 +76,15 @@ file_fields_with_signed_url = { | |||
| "created_by": fields.String, | |||
| "created_at": TimestampField, | |||
| } | |||
| def build_file_with_signed_url_model(api_or_ns: Api | Namespace): | |||
| """Build the file with signed URL model for the API or Namespace. | |||
| Args: | |||
| api_or_ns: Flask-RestX Api or Namespace instance | |||
| Returns: | |||
| The registered model | |||
| """ | |||
| return api_or_ns.model("FileWithSignedUrl", file_fields_with_signed_url) | |||
| @@ -1,8 +1,17 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from libs.helper import AvatarUrlField, TimestampField | |||
| simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} | |||
| simple_account_fields = { | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| "email": fields.String, | |||
| } | |||
| def build_simple_account_model(api_or_ns: Api | Namespace): | |||
| return api_or_ns.model("SimpleAccount", simple_account_fields) | |||
| account_fields = { | |||
| "id": fields.String, | |||
| @@ -1,11 +1,19 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from fields.conversation_fields import message_file_fields | |||
| from libs.helper import TimestampField | |||
| from .raws import FilesContainedField | |||
| feedback_fields = {"rating": fields.String} | |||
| feedback_fields = { | |||
| "rating": fields.String, | |||
| } | |||
| def build_feedback_model(api_or_ns: Api | Namespace): | |||
| """Build the feedback model for the API or Namespace.""" | |||
| return api_or_ns.model("Feedback", feedback_fields) | |||
| agent_thought_fields = { | |||
| "id": fields.String, | |||
| @@ -21,6 +29,12 @@ agent_thought_fields = { | |||
| "files": fields.List(fields.String), | |||
| } | |||
| def build_agent_thought_model(api_or_ns: Api | Namespace): | |||
| """Build the agent thought model for the API or Namespace.""" | |||
| return api_or_ns.model("AgentThought", agent_thought_fields) | |||
| retriever_resource_fields = { | |||
| "id": fields.String, | |||
| "message_id": fields.String, | |||
| @@ -1,3 +1,12 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String} | |||
| dataset_tag_fields = { | |||
| "id": fields.String, | |||
| "name": fields.String, | |||
| "type": fields.String, | |||
| "binding_count": fields.String, | |||
| } | |||
| def build_dataset_tag_fields(api_or_ns: Api | Namespace): | |||
| return api_or_ns.model("DataSetTag", dataset_tag_fields) | |||
| @@ -1,8 +1,8 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from fields.end_user_fields import simple_end_user_fields | |||
| from fields.member_fields import simple_account_fields | |||
| from fields.workflow_run_fields import workflow_run_for_log_fields | |||
| from fields.end_user_fields import build_simple_end_user_model, simple_end_user_fields | |||
| from fields.member_fields import build_simple_account_model, simple_account_fields | |||
| from fields.workflow_run_fields import build_workflow_run_for_log_model, workflow_run_for_log_fields | |||
| from libs.helper import TimestampField | |||
| workflow_app_log_partial_fields = { | |||
| @@ -15,6 +15,24 @@ workflow_app_log_partial_fields = { | |||
| "created_at": TimestampField, | |||
| } | |||
| def build_workflow_app_log_partial_model(api_or_ns: Api | Namespace): | |||
| """Build the workflow app log partial model for the API or Namespace.""" | |||
| workflow_run_model = build_workflow_run_for_log_model(api_or_ns) | |||
| simple_account_model = build_simple_account_model(api_or_ns) | |||
| simple_end_user_model = build_simple_end_user_model(api_or_ns) | |||
| copied_fields = workflow_app_log_partial_fields.copy() | |||
| copied_fields["workflow_run"] = fields.Nested(workflow_run_model, attribute="workflow_run", allow_null=True) | |||
| copied_fields["created_by_account"] = fields.Nested( | |||
| simple_account_model, attribute="created_by_account", allow_null=True | |||
| ) | |||
| copied_fields["created_by_end_user"] = fields.Nested( | |||
| simple_end_user_model, attribute="created_by_end_user", allow_null=True | |||
| ) | |||
| return api_or_ns.model("WorkflowAppLogPartial", copied_fields) | |||
| workflow_app_log_pagination_fields = { | |||
| "page": fields.Integer, | |||
| "limit": fields.Integer, | |||
| @@ -22,3 +40,13 @@ workflow_app_log_pagination_fields = { | |||
| "has_more": fields.Boolean, | |||
| "data": fields.List(fields.Nested(workflow_app_log_partial_fields)), | |||
| } | |||
| def build_workflow_app_log_pagination_model(api_or_ns: Api | Namespace): | |||
| """Build the workflow app log pagination model for the API or Namespace.""" | |||
| # Build the nested partial model first | |||
| workflow_app_log_partial_model = build_workflow_app_log_partial_model(api_or_ns) | |||
| copied_fields = workflow_app_log_pagination_fields.copy() | |||
| copied_fields["data"] = fields.List(fields.Nested(workflow_app_log_partial_model)) | |||
| return api_or_ns.model("WorkflowAppLogPagination", copied_fields) | |||
| @@ -1,4 +1,4 @@ | |||
| from flask_restx import fields | |||
| from flask_restx import Api, Namespace, fields | |||
| from fields.end_user_fields import simple_end_user_fields | |||
| from fields.member_fields import simple_account_fields | |||
| @@ -17,6 +17,11 @@ workflow_run_for_log_fields = { | |||
| "exceptions_count": fields.Integer, | |||
| } | |||
| def build_workflow_run_for_log_model(api_or_ns: Api | Namespace): | |||
| return api_or_ns.model("WorkflowRunForLog", workflow_run_for_log_fields) | |||
| workflow_run_for_list_fields = { | |||
| "id": fields.String, | |||
| "version": fields.String, | |||