소스 검색

fix: missing default user for APP service api (#2606)

tags/0.5.7
takatost 1 년 전
부모
커밋
0828873b52
No account linked to committer's email address

+ 0
- 27
api/controllers/service_api/app/__init__.py 파일 보기

@@ -1,27 +0,0 @@
from extensions.ext_database import db
from models.model import EndUser


def create_or_update_end_user_for_user_id(app_model, user_id):
"""
Create or update session terminal based on user ID.
"""
end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()

if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()

return end_user

+ 8
- 6
api/controllers/service_api/app/app.py 파일 보기

@@ -1,16 +1,16 @@
import json

from flask import current_app
from flask_restful import fields, marshal_with
from flask_restful import fields, marshal_with, Resource

from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.model import App, AppModelConfig
from models.tools import ApiToolProvider


class AppParameterApi(AppApiResource):
class AppParameterApi(Resource):
"""Resource for app variables."""

variable_fields = {
@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
'system_parameters': fields.Nested(system_parameters_fields)
}

@validate_app_token
@marshal_with(parameters_fields)
def get(self, app_model: App, end_user):
def get(self, app_model: App):
"""Retrieve app parameters."""
app_model_config = app_model.app_model_config

@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
}
}

class AppMetaApi(AppApiResource):
def get(self, app_model: App, end_user):
class AppMetaApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get app meta"""
app_model_config: AppModelConfig = app_model.app_model_config


+ 10
- 9
api/controllers/service_api/app/audio.py 파일 보기

@@ -1,7 +1,7 @@
import logging

from flask import request
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError

import services
@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppModelConfig
from models.model import App, AppModelConfig, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -30,8 +30,9 @@ from services.errors.audio import (
)


class AudioApi(AppApiResource):
def post(self, app_model: App, end_user):
class AudioApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
def post(self, app_model: App, end_user: EndUser):
app_model_config: AppModelConfig = app_model.app_model_config

if not app_model_config.speech_to_text_dict['enabled']:
@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
raise InternalServerError()


class TextApi(AppApiResource):
def post(self, app_model: App, end_user):
class TextApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument('text', type=str, required=True, nullable=False, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
args = parser.parse_args()

@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
response = AudioService.transcript_tts(
tenant_id=app_model.tenant_id,
text=args['text'],
end_user=args['user'],
end_user=end_user,
voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
streaming=args['streaming']
)

+ 15
- 41
api/controllers/service_api/app/completion.py 파일 보기

@@ -4,12 +4,11 @@ from collections.abc import Generator
from typing import Union

from flask import Response, stream_with_context
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound

import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App, EndUser
from services.completion_service import CompletionService


class CompletionApi(AppApiResource):
def post(self, app_model, end_user):
class CompletionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'completion':
raise AppUnavailableError()

@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')

args = parser.parse_args()

streaming = args['response_mode'] == 'streaming'

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

args['auto_generate_name'] = False

try:
@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
raise InternalServerError()


class CompletionStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class CompletionStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'completion':
raise AppUnavailableError()

if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()

user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")

ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

return {'result': 'success'}, 200


class ChatApi(AppApiResource):
def post(self, app_model, end_user):
class ChatApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()

@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, required=True, nullable=False, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')

@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):

streaming = args['response_mode'] == 'streaming'

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

try:
response = CompletionService.completion(
app_model=app_model,
@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
raise InternalServerError()


class ChatStopApi(AppApiResource):
def post(self, app_model, end_user, task_id):
class ChatStopApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, task_id):
if app_model.mode != 'chat':
raise NotChatAppError()

if end_user is None:
parser = reqparse.RequestParser()
parser.add_argument('user', required=True, nullable=False, type=str, location='json')
args = parser.parse_args()

user = args.get('user')
if user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)
else:
raise ValueError("arg user muse be input.")

ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)

return {'result': 'success'}, 200

+ 12
- 23
api/controllers/service_api/app/conversation.py 파일 보기

@@ -1,52 +1,44 @@
from flask import request
from flask_restful import marshal_with, reqparse
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import App, EndUser
from services.conversation_service import ConversationService


class ConversationApi(AppApiResource):
class ConversationApi(Resource):

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()

parser = reqparse.RequestParser()
parser.add_argument('last_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

try:
return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

class ConversationDetailApi(AppApiResource):
class ConversationDetailApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def delete(self, app_model, end_user, c_id):
def delete(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()

conversation_id = str(c_id)

user = request.get_json().get('user')

if end_user is None and user is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user)

try:
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
return {"result": "success"}, 204


class ConversationRenameApi(AppApiResource):
class ConversationRenameApi(Resource):

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
def post(self, app_model: App, end_user: EndUser, c_id):
if app_model.mode != 'chat':
raise NotChatAppError()

@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):

parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
args = parser.parse_args()

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

try:
return ConversationService.rename(
app_model,

+ 6
- 9
api/controllers/service_api/app/file.py 파일 보기

@@ -1,30 +1,27 @@
from flask import request
from flask_restful import marshal_with
from flask_restful import Resource, marshal_with

import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import file_fields
from models.model import App, EndUser
from services.file_service import FileService


class FileApi(AppApiResource):
class FileApi(Resource):

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
@marshal_with(file_fields)
def post(self, app_model, end_user):
def post(self, app_model: App, end_user: EndUser):

file = request.files['file']
user_args = request.form.get('user')

if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)

# check file
if 'file' not in request.files:

+ 14
- 34
api/controllers/service_api/app/message.py 파일 보기

@@ -1,20 +1,18 @@
from flask_restful import fields, marshal_with, reqparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound

import services
from controllers.service_api import api
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from extensions.ext_database import db
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from models.model import EndUser, Message
from models.model import App, EndUser
from services.message_service import MessageService


class MessageListApi(AppApiResource):
class MessageListApi(Resource):
feedback_fields = {
'rating': fields.String
}
@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
'data': fields.List(fields.Nested(message_fields))
}

@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
def get(self, app_model: App, end_user: EndUser):
if app_model.mode != 'chat':
raise NotChatAppError()

@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
parser.add_argument('first_id', type=uuid_value, location='args')
parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
parser.add_argument('user', type=str, location='args')
args = parser.parse_args()

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

try:
return MessageService.pagination_by_first_id(app_model, end_user,
args['conversation_id'], args['first_id'], args['limit'])
@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
raise NotFound("First Message Not Exists.")


class MessageFeedbackApi(AppApiResource):
def post(self, app_model, end_user, message_id):
class MessageFeedbackApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)

parser = reqparse.RequestParser()
parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
parser.add_argument('user', type=str, location='json')
args = parser.parse_args()

if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])

try:
MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
except services.errors.message.MessageNotExistsError:
@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}


class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()

try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()

if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
user=end_user,
message_id=message_id,
check_enabled=False
)

+ 76
- 14
api/controllers/service_api/wraps.py 파일 보기

@@ -1,22 +1,40 @@
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Optional

from flask import current_app, request
from flask_login import user_logged_in
from flask_restful import Resource
from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized

from extensions.ext_database import db
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin
from models.model import ApiToken, App
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService


def validate_app_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
class WhereisUserArg(Enum):
"""
Enum for whereis_user_arg.
"""
QUERY = 'query'
JSON = 'json'
FORM = 'form'


class FetchUserArg(BaseModel):
fetch_from: WhereisUserArg
required: bool = False


def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token('app')

app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
@@ -29,15 +47,34 @@ def validate_app_token(view=None):
if not app_model.enable_api:
raise NotFound()

return view(app_model, None, *args, **kwargs)
return decorated
kwargs['app_model'] = app_model

if view:
return decorator(view)
if not fetch_user_arg:
# use default-user
user_id = None
else:
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get('user')
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get('user')
else:
# use default-user
user_id = None

# if view is None, it means that the decorator is used without parentheses
# use the decorator as a function for method_decorators
return decorator
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")

kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)

return view_func(*args, **kwargs)
return decorated_view

if view is None:
return decorator
else:
return decorator(view)


def cloud_edition_billing_resource_check(resource: str,
@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
return api_token


class AppApiResource(Resource):
method_decorators = [validate_app_token]
def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
"""
Create or update session terminal based on user ID.
"""
if not user_id:
user_id = 'DEFAULT-USER'

end_user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
EndUser.type == 'service_api'
).first()

if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='service_api',
is_anonymous=True if user_id == 'DEFAULT-USER' else False,
session_id=user_id
)
db.session.add(end_user)
db.session.commit()

return end_user


class DatasetApiResource(Resource):

Loading…
취소
저장