浏览代码

Add Authorization checks (#2221)

### What problem does this PR solve?

Add Authorization checks
#2203

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
tags/v0.11.0
LiuHua 1年前
父节点
当前提交
0164856343
没有帐户链接到提交者的电子邮件
共有 4 个文件被更改,包括 79 次插入25 次删除
  1. 5
    0
      api/apps/canvas_app.py
  2. 50
    21
      api/apps/conversation_app.py
  3. 14
    3
      api/apps/dialog_app.py
  4. 10
    1
      api/apps/document_app.py

+ 5
- 0
api/apps/canvas_app.py 查看文件

from flask import request, Response from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from agent.canvas import Canvas from agent.canvas import Canvas
@login_required @login_required
def rm(): def rm():
for i in request.json["canvas_ids"]: for i in request.json["canvas_ids"]:
if not UserCanvasService.query(user_id=current_user.id,id=i):
return get_json_result(
data=False, retmsg=f'Only owner of canvas authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
UserCanvasService.delete_by_id(i) UserCanvasService.delete_by_id(i)
return get_json_result(data=True) return get_json_result(data=True)



+ 50
- 21
api/apps/conversation_app.py 查看文件

# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
import json
from copy import deepcopy from copy import deepcopy

from db.services.user_service import UserTenantService
from flask import request, Response from flask import request, Response
from flask_login import login_required,current_user
from flask_login import login_required, current_user

from api.db import LLMType
from api.db.services.dialog_service import DialogService, ConversationService, chat from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.llm_service import LLMBundle, TenantService from api.db.services.llm_service import LLMBundle, TenantService
from api.db import LLMType
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
import json
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request




@manager.route('/set', methods=['POST']) @manager.route('/set', methods=['POST'])
e, conv = ConversationService.get_by_id(conv_id) e, conv = ConversationService.get_by_id(conv_id)
if not e: if not e:
return get_data_error_result(retmsg="Conversation not found!") return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
conv = conv.to_dict() conv = conv.to_dict()
return get_json_result(data=conv) return get_json_result(data=conv)
except Exception as e: except Exception as e:
conv_ids = request.json["conversation_ids"] conv_ids = request.json["conversation_ids"]
try: try:
for cid in conv_ids: for cid in conv_ids:
exist, conv = ConversationService.get_by_id(cid)
if not exist:
return get_data_error_result(retmsg="Conversation not found!")
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of conversation authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
ConversationService.delete_by_id(cid) ConversationService.delete_by_id(cid)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
def list_convsersation(): def list_convsersation():
dialog_id = request.args["dialog_id"] dialog_id = request.args["dialog_id"]
try: try:
if not DialogService.query(tenant_id=current_user.id, id=dialog_id):
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
convs = ConversationService.query( convs = ConversationService.query(
dialog_id=dialog_id, dialog_id=dialog_id,
order_by=ConversationService.model.create_time, order_by=ConversationService.model.create_time,


@manager.route('/completion', methods=['POST']) @manager.route('/completion', methods=['POST'])
@login_required @login_required
#@validate_request("conversation_id", "messages")
@validate_request("conversation_id", "messages")
def completion(): def completion():
req = request.json req = request.json
#req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
# {"role": "user", "content": "上海有吗?"} # {"role": "user", "content": "上海有吗?"}
#]}
# ]}
msg = [] msg = []
for m in req["messages"]: for m in req["messages"]:
if m["role"] == "system": if m["role"] == "system":
nonlocal conv, message_id nonlocal conv, message_id
if not conv.reference: if not conv.reference:
conv.reference.append(ans["reference"]) conv.reference.append(ans["reference"])
else: conv.reference[-1] = ans["reference"]
else:
conv.reference[-1] = ans["reference"]
conv.message[-1] = {"role": "assistant", "content": ans["answer"], conv.message[-1] = {"role": "assistant", "content": ans["answer"],
"id": message_id, "prompt": ans.get("prompt", "")} "id": message_id, "prompt": ans.get("prompt", "")}
ans["id"] = message_id ans["id"] = message_id
try: try:
for ans in chat(dia, msg, True, **req): for ans in chat(dia, msg, True, **req):
fillin_conv(ans) fillin_conv(ans)
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict()) ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e: except Exception as e:
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e), "reference": []}},
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n" ensure_ascii=False) + "\n\n"
yield "data:"+json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"


if req.get("stream", True): if req.get("stream", True):
resp = Response(stream(), mimetype="text/event-stream") resp = Response(stream(), mimetype="text/event-stream")
def tts(): def tts():
req = request.json req = request.json
text = req["text"] text = req["text"]
tenants = TenantService.get_by_user_id(current_user.id) tenants = TenantService.get_by_user_id(current_user.id)
if not tenants: if not tenants:
return get_data_error_result(retmsg="Tenant not found!") return get_data_error_result(retmsg="Tenant not found!")
tts_id = tenants[0]["tts_id"] tts_id = tenants[0]["tts_id"]
if not tts_id: if not tts_id:
return get_data_error_result(retmsg="No default TTS model is set") return get_data_error_result(retmsg="No default TTS model is set")
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id) tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)

def stream_audio(): def stream_audio():
try: try:
for chunk in tts_mdl.tts(text): for chunk in tts_mdl.tts(text):
yield chunk yield chunk
except Exception as e: except Exception as e:
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e), yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
"data": {"answer": "**ERROR**: "+str(e)}},
ensure_ascii=False)).encode('utf-8')
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')


resp = Response(stream_audio(), mimetype="audio/mpeg")
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache") resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("X-Accel-Buffering", "no")
return resp return resp


@manager.route('/delete_msg', methods=['POST']) @manager.route('/delete_msg', methods=['POST'])
@login_required @login_required
@validate_request("conversation_id", "message_id") @validate_request("conversation_id", "message_id")
for i, msg in enumerate(conv["message"]): for i, msg in enumerate(conv["message"]):
if req["message_id"] != msg.get("id", ""): if req["message_id"] != msg.get("id", ""):
continue continue
assert conv["message"][i+1]["id"] == req["message_id"]
assert conv["message"][i + 1]["id"] == req["message_id"]
conv["message"].pop(i) conv["message"].pop(i)
conv["message"].pop(i) conv["message"].pop(i)
conv["reference"].pop(max(0, i//2-1))
conv["reference"].pop(max(0, i // 2 - 1))
break break


ConversationService.update_by_id(conv["id"], conv) ConversationService.update_by_id(conv["id"], conv)

+ 14
- 3
api/apps/dialog_app.py 查看文件

from api.db.services.dialog_service import DialogService from api.db.services.dialog_service import DialogService
from api.db import StatusEnum from api.db import StatusEnum
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService
from api.db.services.user_service import TenantService, UserTenantService
from api.settings import RetCode
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_json_result from api.utils.api_utils import get_json_result
@validate_request("dialog_ids") @validate_request("dialog_ids")
def rm(): def rm():
req = request.json req = request.json
dialog_list=[]
tenants = UserTenantService.query(user_id=current_user.id)
try: try:
DialogService.update_many_by_id(
[{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]])
for id in req["dialog_ids"]:
for tenant in tenants:
if DialogService.query(tenant_id=tenant.tenant_id, id=id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of dialog authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
dialog_list.append({"id": id,"status":StatusEnum.INVALID.value})
DialogService.update_many_by_id(dialog_list)
return get_json_result(data=True) return get_json_result(data=True)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

+ 10
- 1
api/apps/document_app.py 查看文件

from api.db.services.file_service import FileService from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.task_service import TaskService, queue_tasks from api.db.services.task_service import TaskService, queue_tasks
from api.db.services.user_service import TenantService
from api.db.services.user_service import TenantService, UserTenantService
from graphrag.mind_map_extractor import MindMapExtractor from graphrag.mind_map_extractor import MindMapExtractor
from rag.app import naive from rag.app import naive
from rag.nlp import search from rag.nlp import search
if not kb_id: if not kb_id:
return get_json_result( return get_json_result(
data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR) data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
tenants = UserTenantService.query(user_id=current_user.id)
for tenant in tenants:
if KnowledgebaseService.query(
tenant_id=tenant.tenant_id, id=kb_id):
break
else:
return get_json_result(
data=False, retmsg=f'Only owner of knowledgebase authorized for this operation.',
retcode=RetCode.OPERATING_ERROR)
keywords = request.args.get("keywords", "") keywords = request.args.get("keywords", "")


page_number = int(request.args.get("page", 1)) page_number = int(request.args.get("page", 1))

正在加载...
取消
保存