### What problem does this PR solve? SDK for session #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>tags/v0.11.0
| from flask import request | from flask import request | ||||
| from api.db import StatusEnum | from api.db import StatusEnum | ||||
| from api.db.db_models import TenantLLM | |||||
| from api.db.services.dialog_service import DialogService | from api.db.services.dialog_service import DialogService | ||||
| from api.db.services.document_service import DocumentService | |||||
| from api.db.services.knowledgebase_service import KnowledgebaseService | from api.db.services.knowledgebase_service import KnowledgebaseService | ||||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||||
| from api.db.services.user_service import TenantService | from api.db.services.user_service import TenantService | ||||
| from api.settings import RetCode | from api.settings import RetCode | ||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| @token_required | @token_required | ||||
| def save(tenant_id): | def save(tenant_id): | ||||
| req = request.json | req = request.json | ||||
| id = req.get("id") | |||||
| # dataset | # dataset | ||||
| if req.get("knowledgebases") == []: | if req.get("knowledgebases") == []: | ||||
| return get_data_error_result(retmsg="knowledgebases can not be empty list") | return get_data_error_result(retmsg="knowledgebases can not be empty list") | ||||
| return get_data_error_result(retmsg="knowledgebase needs id") | return get_data_error_result(retmsg="knowledgebase needs id") | ||||
| if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id): | if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id): | ||||
| return get_data_error_result(retmsg="you do not own the knowledgebase") | return get_data_error_result(retmsg="you do not own the knowledgebase") | ||||
| if not DocumentService.query(kb_id=kb["id"]): | |||||
| return get_data_error_result(retmsg="There is a invalid knowledgebase") | |||||
| # if not DocumentService.query(kb_id=kb["id"]): | |||||
| # return get_data_error_result(retmsg="There is a invalid knowledgebase") | |||||
| kb_list.append(kb["id"]) | kb_list.append(kb["id"]) | ||||
| req["kb_ids"] = kb_list | req["kb_ids"] = kb_list | ||||
| # llm | # llm | ||||
| req[key] = prompt.pop(key) | req[key] = prompt.pop(key) | ||||
| req["prompt_config"] = req.pop("prompt") | req["prompt_config"] = req.pop("prompt") | ||||
| # create | # create | ||||
| if not id: | |||||
| if "id" not in req: | |||||
| # dataset | # dataset | ||||
| if not kb_list: | if not kb_list: | ||||
| return get_data_error_result(retmsg="knowledgebase is required!") | |||||
| return get_data_error_result(retmsg="knowledgebases are required!") | |||||
| # init | # init | ||||
| req["id"] = get_uuid() | req["id"] = get_uuid() | ||||
| req["description"] = req.get("description", "A helpful Assistant") | req["description"] = req.get("description", "A helpful Assistant") | ||||
| req["top_n"] = req.get("top_n", 6) | req["top_n"] = req.get("top_n", 6) | ||||
| req["top_k"] = req.get("top_k", 1024) | req["top_k"] = req.get("top_k", 1024) | ||||
| req["rerank_id"] = req.get("rerank_id", "") | req["rerank_id"] = req.get("rerank_id", "") | ||||
| req["llm_id"] = req.get("llm_id", tenant.llm_id) | |||||
| if req.get("llm_id"): | |||||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||||
| return get_data_error_result(retmsg="the model_name does not exist.") | |||||
| else: | |||||
| req["llm_id"] = tenant.llm_id | |||||
| if not req.get("name"): | if not req.get("name"): | ||||
| return get_data_error_result(retmsg="name is required.") | return get_data_error_result(retmsg="name is required.") | ||||
| if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | ||||
| if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value): | if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value): | ||||
| return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR) | return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR) | ||||
| # prompt | # prompt | ||||
| if not req["id"]: | |||||
| return get_data_error_result(retmsg="id can not be empty") | |||||
| e, res = DialogService.get_by_id(req["id"]) | e, res = DialogService.get_by_id(req["id"]) | ||||
| res = res.to_json() | res = res.to_json() | ||||
| if "llm_id" in req: | |||||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||||
| return get_data_error_result(retmsg="the model_name does not exist.") | |||||
| if "name" in req: | if "name" in req: | ||||
| if not req.get("name"): | if not req.get("name"): | ||||
| return get_data_error_result(retmsg="name is not empty.") | return get_data_error_result(retmsg="name is not empty.") | ||||
| if req["name"].lower() != res["name"].lower() \ | if req["name"].lower() != res["name"].lower() \ | ||||
| and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0: | |||||
| return get_data_error_result(retmsg="Duplicated knowledgebase name in updating dataset.") | |||||
| and len( | |||||
| DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: | |||||
| return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.") | |||||
| if "prompt_config" in req: | if "prompt_config" in req: | ||||
| res["prompt_config"].update(req["prompt_config"]) | res["prompt_config"].update(req["prompt_config"]) | ||||
| for p in res["prompt_config"]["parameters"]: | for p in res["prompt_config"]["parameters"]: | ||||
| if "id" not in req: | if "id" not in req: | ||||
| return get_data_error_result(retmsg="id is required") | return get_data_error_result(retmsg="id is required") | ||||
| id = req['id'] | id = req['id'] | ||||
| if not DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value): | |||||
| if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value): | |||||
| return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR) | return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR) | ||||
| temp_dict = {"status": StatusEnum.INVALID.value} | temp_dict = {"status": StatusEnum.INVALID.value} | ||||
| req = request.args | req = request.args | ||||
| if "id" in req: | if "id" in req: | ||||
| id = req["id"] | id = req["id"] | ||||
| ass = DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value) | |||||
| ass = DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value) | |||||
| if not ass: | if not ass: | ||||
| return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR) | return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR) | ||||
| if "name" in req: | if "name" in req: | ||||
| name = req["name"] | name = req["name"] | ||||
| if ass[0].name != name: | if ass[0].name != name: | ||||
| return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR) | return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR) | ||||
| res=ass[0].to_json() | |||||
| res = ass[0].to_json() | |||||
| else: | else: | ||||
| if "name" in req: | if "name" in req: | ||||
| name = req["name"] | name = req["name"] | ||||
| ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value) | |||||
| ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value) | |||||
| if not ass: | if not ass: | ||||
| return get_json_result(data=False, retmsg='You do not own the dataset.',retcode=RetCode.OPERATING_ERROR) | |||||
| res=ass[0].to_json() | |||||
| return get_json_result(data=False, retmsg='You do not own the assistant.', | |||||
| retcode=RetCode.OPERATING_ERROR) | |||||
| res = ass[0].to_json() | |||||
| else: | else: | ||||
| return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.") | return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.") | ||||
| renamed_dict = {} | renamed_dict = {} | ||||
| reverse=True, | reverse=True, | ||||
| order_by=DialogService.model.create_time) | order_by=DialogService.model.create_time) | ||||
| assts = [d.to_dict() for d in assts] | assts = [d.to_dict() for d in assts] | ||||
| list_assts=[] | |||||
| list_assts = [] | |||||
| renamed_dict = {} | renamed_dict = {} | ||||
| key_mapping = {"parameters": "variables", | key_mapping = {"parameters": "variables", | ||||
| "prologue": "opener", | "prologue": "opener", |
| req.update(mapped_keys) | req.update(mapped_keys) | ||||
| if not KnowledgebaseService.save(**req): | if not KnowledgebaseService.save(**req): | ||||
| return get_data_error_result(retmsg="Create dataset error.(Database error)") | return get_data_error_result(retmsg="Create dataset error.(Database error)") | ||||
| renamed_data={} | |||||
| renamed_data = {} | |||||
| e, k = KnowledgebaseService.get_by_id(req["id"]) | e, k = KnowledgebaseService.get_by_id(req["id"]) | ||||
| for key, value in k.to_dict().items(): | for key, value in k.to_dict().items(): | ||||
| new_key = key_mapping.get(key, key) | new_key = key_mapping.get(key, key) | ||||
| data=False, retmsg='You do not own the dataset.', | data=False, retmsg='You do not own the dataset.', | ||||
| retcode=RetCode.OPERATING_ERROR) | retcode=RetCode.OPERATING_ERROR) | ||||
| if not req["id"]: | |||||
| return get_data_error_result( | |||||
| retmsg="id can not be empty.") | |||||
| e, kb = KnowledgebaseService.get_by_id(req["id"]) | e, kb = KnowledgebaseService.get_by_id(req["id"]) | ||||
| if "chunk_count" in req: | if "chunk_count" in req: | ||||
| retmsg="If chunk count is not 0, parse method is not changable.") | retmsg="If chunk count is not 0, parse method is not changable.") | ||||
| req['parser_id'] = req.pop('parse_method') | req['parser_id'] = req.pop('parse_method') | ||||
| if "name" in req: | if "name" in req: | ||||
| req["name"] = req["name"].strip() | |||||
| if req["name"].lower() != kb.name.lower() \ | if req["name"].lower() != kb.name.lower() \ | ||||
| and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, | and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, | ||||
| status=StatusEnum.VALID.value)) > 0: | status=StatusEnum.VALID.value)) > 0: |
| # | |||||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # | |||||
| import json | |||||
| from copy import deepcopy | |||||
| from uuid import uuid4 | |||||
| from flask import request, Response | |||||
| from api.db import StatusEnum | |||||
| from api.db.services.dialog_service import DialogService, ConversationService, chat | |||||
| from api.utils import get_uuid | |||||
| from api.utils.api_utils import get_data_error_result | |||||
| from api.utils.api_utils import get_json_result, token_required | |||||
| @manager.route('/save', methods=['POST']) | |||||
| @token_required | |||||
| def set_conversation(tenant_id): | |||||
| req = request.json | |||||
| conv_id = req.get("id") | |||||
| if "messages" in req: | |||||
| req["message"] = req.pop("messages") | |||||
| if req["message"]: | |||||
| for message in req["message"]: | |||||
| if "reference" in message: | |||||
| req["reference"] = message.pop("reference") | |||||
| if "assistant_id" in req: | |||||
| req["dialog_id"] = req.pop("assistant_id") | |||||
| if "id" in req: | |||||
| del req["id"] | |||||
| conv = ConversationService.query(id=conv_id) | |||||
| if not conv: | |||||
| return get_data_error_result(retmsg="Session does not exist") | |||||
| if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_data_error_result(retmsg="You do not own the session") | |||||
| if req.get("dialog_id"): | |||||
| dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) | |||||
| if not dia: | |||||
| return get_data_error_result(retmsg="You do not own the assistant") | |||||
| if "dialog_id" in req and not req.get("dialog_id"): | |||||
| return get_data_error_result(retmsg="assistant_id can not be empty.") | |||||
| if "name" in req and not req.get("name"): | |||||
| return get_data_error_result(retmsg="name can not be empty.") | |||||
| if "message" in req and not req.get("message"): | |||||
| return get_data_error_result(retmsg="messages can not be empty") | |||||
| if not ConversationService.update_by_id(conv_id, req): | |||||
| return get_data_error_result(retmsg="Session updates error") | |||||
| return get_json_result(data=True) | |||||
| if not req.get("dialog_id"): | |||||
| return get_data_error_result(retmsg="assistant_id is required.") | |||||
| dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) | |||||
| if not dia: | |||||
| return get_data_error_result(retmsg="You do not own the assistant") | |||||
| conv = { | |||||
| "id": get_uuid(), | |||||
| "dialog_id": req["dialog_id"], | |||||
| "name": req.get("name", "New session"), | |||||
| "message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]), | |||||
| "reference": req.get("reference", []) | |||||
| } | |||||
| if not conv.get("name"): | |||||
| return get_data_error_result(retmsg="name can not be empty.") | |||||
| if not conv.get("message"): | |||||
| return get_data_error_result(retmsg="messages can not be empty") | |||||
| ConversationService.save(**conv) | |||||
| e, conv = ConversationService.get_by_id(conv["id"]) | |||||
| if not e: | |||||
| return get_data_error_result(retmsg="Fail to new session!") | |||||
| conv = conv.to_dict() | |||||
| conv["messages"] = conv.pop("message") | |||||
| conv["assistant_id"] = conv.pop("dialog_id") | |||||
| for message in conv["messages"]: | |||||
| message["reference"] = conv.get("reference") | |||||
| del conv["reference"] | |||||
| return get_json_result(data=conv) | |||||
| @manager.route('/completion', methods=['POST']) | |||||
| @token_required | |||||
| def completion(tenant_id): | |||||
| req = request.json | |||||
| # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ | |||||
| # {"role": "user", "content": "上海有吗?"} | |||||
| # ]} | |||||
| msg = [] | |||||
| question = { | |||||
| "content": req.get("question"), | |||||
| "role": "user", | |||||
| "id": str(uuid4()) | |||||
| } | |||||
| req["messages"].append(question) | |||||
| for m in req["messages"]: | |||||
| if m["role"] == "system": continue | |||||
| if m["role"] == "assistant" and not msg: continue | |||||
| m["id"] = m.get("id", str(uuid4())) | |||||
| msg.append(m) | |||||
| message_id = msg[-1].get("id") | |||||
| conv = ConversationService.query(id=req["id"]) | |||||
| conv = conv[0] | |||||
| if not conv: | |||||
| return get_data_error_result(retmsg="Session does not exist") | |||||
| if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_data_error_result(retmsg="You do not own the session") | |||||
| conv.message = deepcopy(req["messages"]) | |||||
| e, dia = DialogService.get_by_id(conv.dialog_id) | |||||
| if not e: | |||||
| return get_data_error_result(retmsg="Dialog not found!") | |||||
| del req["id"] | |||||
| del req["messages"] | |||||
| if not conv.reference: | |||||
| conv.reference = [] | |||||
| conv.message.append({"role": "assistant", "content": "", "id": message_id}) | |||||
| conv.reference.append({"chunks": [], "doc_aggs": []}) | |||||
| def fillin_conv(ans): | |||||
| nonlocal conv, message_id | |||||
| if not conv.reference: | |||||
| conv.reference.append(ans["reference"]) | |||||
| else: | |||||
| conv.reference[-1] = ans["reference"] | |||||
| conv.message[-1] = {"role": "assistant", "content": ans["answer"], | |||||
| "id": message_id, "prompt": ans.get("prompt", "")} | |||||
| ans["id"] = message_id | |||||
| def stream(): | |||||
| nonlocal dia, msg, req, conv | |||||
| try: | |||||
| for ans in chat(dia, msg, **req): | |||||
| fillin_conv(ans) | |||||
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" | |||||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | |||||
| except Exception as e: | |||||
| yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |||||
| "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, | |||||
| ensure_ascii=False) + "\n\n" | |||||
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" | |||||
| if req.get("stream", True): | |||||
| resp = Response(stream(), mimetype="text/event-stream") | |||||
| resp.headers.add_header("Cache-control", "no-cache") | |||||
| resp.headers.add_header("Connection", "keep-alive") | |||||
| resp.headers.add_header("X-Accel-Buffering", "no") | |||||
| resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") | |||||
| return resp | |||||
| else: | |||||
| answer = None | |||||
| for ans in chat(dia, msg, **req): | |||||
| answer = ans | |||||
| fillin_conv(ans) | |||||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | |||||
| break | |||||
| return get_json_result(data=answer) |
| from typing import List | |||||
| from .base import Base | from .base import Base | ||||
| from .session import Session, Message | |||||
| class Assistant(Base): | class Assistant(Base): | ||||
| def __init__(self, rag, res_dict): | def __init__(self, rag, res_dict): | ||||
| self.id="" | |||||
| self.id = "" | |||||
| self.name = "assistant" | self.name = "assistant" | ||||
| self.avatar = "path/to/avatar" | self.avatar = "path/to/avatar" | ||||
| self.knowledgebases = ["kb1"] | self.knowledgebases = ["kb1"] | ||||
| def save(self) -> bool: | def save(self) -> bool: | ||||
| res = self.post('/assistant/save', | res = self.post('/assistant/save', | ||||
| {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases, | |||||
| "llm":self.llm.to_json(),"prompt":self.prompt.to_json() | |||||
| {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases, | |||||
| "llm": self.llm.to_json(), "prompt": self.prompt.to_json() | |||||
| }) | }) | ||||
| res = res.json() | res = res.json() | ||||
| if res.get("retmsg") == "success": return True | if res.get("retmsg") == "success": return True | ||||
| res = res.json() | res = res.json() | ||||
| if res.get("retmsg") == "success": return True | if res.get("retmsg") == "success": return True | ||||
| raise Exception(res["retmsg"]) | raise Exception(res["retmsg"]) | ||||
| def create_session(self, name: str = "New session", messages: List[Message] = [ | |||||
| {"role": "assistant", "reference": [], | |||||
| "content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session: | |||||
| res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, }) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": | |||||
| return Session(self.rag, res['data']) | |||||
| raise Exception(res["retmsg"]) | |||||
| def get_prologue(self): | |||||
| return self.prompt.opener |
| import json | |||||
| from .base import Base | |||||
| class Session(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.id = None | |||||
| self.name = "New session" | |||||
| self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | |||||
| self.assistant_id = None | |||||
| super().__init__(rag, res_dict) | |||||
| def chat(self, question: str, stream: bool = False): | |||||
| res = self.post("/session/completion", | |||||
| {"id": self.id, "question": question, "stream": stream, "messages": self.messages}) | |||||
| res = res.text | |||||
| response_lines = res.splitlines() | |||||
| message_list = [] | |||||
| for line in response_lines: | |||||
| if line.startswith("data:"): | |||||
| json_data = json.loads(line[5:]) | |||||
| if json_data["data"] != True: | |||||
| answer = json_data["data"]["answer"] | |||||
| reference = json_data["data"]["reference"] | |||||
| temp_dict = { | |||||
| "content": answer, | |||||
| "role": "assistant", | |||||
| "reference": reference | |||||
| } | |||||
| message = Message(self.rag, temp_dict) | |||||
| message_list.append(message) | |||||
| return message_list | |||||
| def save(self): | |||||
| res = self.post("/session/save", | |||||
| {"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": return True | |||||
| raise Exception(res.get("retmsg")) | |||||
| class Message(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.content = "您好,我是您的助手小樱,长得可爱又善良,can I help you?" | |||||
| self.reference = [] | |||||
| self.role = "assistant" | |||||
| self.prompt=None | |||||
| super().__init__(rag, res_dict) | |||||
| class Chunk(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.id = None | |||||
| self.content = None | |||||
| self.document_id = None | |||||
| self.document_name = None | |||||
| self.knowledgebase_id = None | |||||
| self.image_id = None | |||||
| self.similarity = None | |||||
| self.vector_similarity = None | |||||
| self.term_similarity = None | |||||
| self.positions = None | |||||
| super().__init__(rag, res_dict) |
| import requests | import requests | ||||
| from .modules.chat_assistant import Assistant | from .modules.chat_assistant import Assistant | ||||
| from .modules.dataset import DataSet | from .modules.dataset import DataSet | ||||
| datasets.append(dataset.to_json()) | datasets.append(dataset.to_json()) | ||||
| if llm is None: | if llm is None: | ||||
| llm = Assistant.LLM(self, {"model_name": "deepseek-chat", | |||||
| llm = Assistant.LLM(self, {"model_name": None, | |||||
| "temperature": 0.1, | "temperature": 0.1, | ||||
| "top_p": 0.3, | "top_p": 0.3, | ||||
| "presence_penalty": 0.4, | "presence_penalty": 0.4, | ||||
| for data in res['data']: | for data in res['data']: | ||||
| result_list.append(Assistant(self, data)) | result_list.append(Assistant(self, data)) | ||||
| return result_list | return result_list | ||||
| raise Exception(res["retmsg"]) | |||||
| raise Exception(res["retmsg"]) |
| Test creating an assistant with success | Test creating an assistant with success | ||||
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.get_dataset(name="God") | |||||
| assistant = rag.create_assistant("God",knowledgebases=[kb]) | |||||
| kb = rag.create_dataset(name="test_create_assistant") | |||||
| assistant = rag.create_assistant("test_create", knowledgebases=[kb]) | |||||
| if isinstance(assistant, Assistant): | if isinstance(assistant, Assistant): | ||||
| assert assistant.name == "God", "Name does not match." | |||||
| assert assistant.name == "test_create", "Name does not match." | |||||
| else: | else: | ||||
| assert False, f"Failed to create assistant, error: {assistant}" | assert False, f"Failed to create assistant, error: {assistant}" | ||||
| Test updating an assistant with success. | Test updating an assistant with success. | ||||
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.get_dataset(name="God") | |||||
| assistant = rag.create_assistant("ABC",knowledgebases=[kb]) | |||||
| kb = rag.create_dataset(name="test_update_assistant") | |||||
| assistant = rag.create_assistant("test_update", knowledgebases=[kb]) | |||||
| if isinstance(assistant, Assistant): | if isinstance(assistant, Assistant): | ||||
| assert assistant.name == "ABC", "Name does not match." | |||||
| assistant.name = 'DEF' | |||||
| assert assistant.name == "test_update", "Name does not match." | |||||
| assistant.name = 'new_assistant' | |||||
| res = assistant.save() | res = assistant.save() | ||||
| assert res is True, f"Failed to update assistant, error: {res}" | assert res is True, f"Failed to update assistant, error: {res}" | ||||
| else: | else: | ||||
| Test deleting an assistant with success | Test deleting an assistant with success | ||||
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.get_dataset(name="God") | |||||
| assistant = rag.create_assistant("MA",knowledgebases=[kb]) | |||||
| kb = rag.create_dataset(name="test_delete_assistant") | |||||
| assistant = rag.create_assistant("test_delete", knowledgebases=[kb]) | |||||
| if isinstance(assistant, Assistant): | if isinstance(assistant, Assistant): | ||||
| assert assistant.name == "MA", "Name does not match." | |||||
| assert assistant.name == "test_delete", "Name does not match." | |||||
| res = assistant.delete() | res = assistant.delete() | ||||
| assert res is True, f"Failed to delete assistant, error: {res}" | assert res is True, f"Failed to delete assistant, error: {res}" | ||||
| else: | else: | ||||
| Test getting an assistant's detail with success | Test getting an assistant's detail with success | ||||
| """ | """ | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| assistant = rag.get_assistant(name="God") | |||||
| kb = rag.create_dataset(name="test_get_assistant") | |||||
| rag.create_assistant("test_get_assistant", knowledgebases=[kb]) | |||||
| assistant = rag.get_assistant(name="test_get_assistant") | |||||
| assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}." | assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}." | ||||
| assert assistant.name == "God", "Name does not match" | |||||
| assert assistant.name == "test_get_assistant", "Name does not match" |
| from ragflow import RAGFlow | |||||
| from common import API_KEY, HOST_ADDRESS | |||||
| class TestChatSession: | |||||
| def test_create_session(self): | |||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||||
| kb = rag.create_dataset(name="test_create_session") | |||||
| assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb]) | |||||
| session = assistant.create_session() | |||||
| assert assistant is not None, "Failed to get the assistant." | |||||
| assert session is not None, "Failed to create a session." | |||||
| def test_create_chat_with_success(self): | |||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||||
| kb = rag.create_dataset(name="test_create_chat") | |||||
| assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb]) | |||||
| session = assistant.create_session() | |||||
| assert session is not None, "Failed to create a session." | |||||
| prologue = assistant.get_prologue() | |||||
| assert isinstance(prologue, str), "Prologue is not a string." | |||||
| assert len(prologue) > 0, "Prologue is empty." | |||||
| question = "What is AI" | |||||
| ans = session.chat(question, stream=True) | |||||
| response = ans[-1].content | |||||
| assert len(response) > 0, "Assistant did not return any response." |