### What problem does this PR solve? SDK for session #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Feiue <10215101452@stu.ecun.edu.cn> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>tags/v0.11.0
| @@ -16,9 +16,10 @@ | |||
| from flask import request | |||
| from api.db import StatusEnum | |||
| from api.db.db_models import TenantLLM | |||
| from api.db.services.dialog_service import DialogService | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.db.services.llm_service import LLMService, TenantLLMService | |||
| from api.db.services.user_service import TenantService | |||
| from api.settings import RetCode | |||
| from api.utils import get_uuid | |||
| @@ -30,7 +31,6 @@ from api.utils.api_utils import get_json_result | |||
| @token_required | |||
| def save(tenant_id): | |||
| req = request.json | |||
| id = req.get("id") | |||
| # dataset | |||
| if req.get("knowledgebases") == []: | |||
| return get_data_error_result(retmsg="knowledgebases can not be empty list") | |||
| @@ -41,8 +41,8 @@ def save(tenant_id): | |||
| return get_data_error_result(retmsg="knowledgebase needs id") | |||
| if not KnowledgebaseService.query(id=kb["id"], tenant_id=tenant_id): | |||
| return get_data_error_result(retmsg="you do not own the knowledgebase") | |||
| if not DocumentService.query(kb_id=kb["id"]): | |||
| return get_data_error_result(retmsg="There is a invalid knowledgebase") | |||
| # if not DocumentService.query(kb_id=kb["id"]): | |||
| # return get_data_error_result(retmsg="There is a invalid knowledgebase") | |||
| kb_list.append(kb["id"]) | |||
| req["kb_ids"] = kb_list | |||
| # llm | |||
| @@ -72,10 +72,10 @@ def save(tenant_id): | |||
| req[key] = prompt.pop(key) | |||
| req["prompt_config"] = req.pop("prompt") | |||
| # create | |||
| if not id: | |||
| if "id" not in req: | |||
| # dataset | |||
| if not kb_list: | |||
| return get_data_error_result(retmsg="knowledgebase is required!") | |||
| return get_data_error_result(retmsg="knowledgebases are required!") | |||
| # init | |||
| req["id"] = get_uuid() | |||
| req["description"] = req.get("description", "A helpful Assistant") | |||
| @@ -83,7 +83,11 @@ def save(tenant_id): | |||
| req["top_n"] = req.get("top_n", 6) | |||
| req["top_k"] = req.get("top_k", 1024) | |||
| req["rerank_id"] = req.get("rerank_id", "") | |||
| req["llm_id"] = req.get("llm_id", tenant.llm_id) | |||
| if req.get("llm_id"): | |||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||
| return get_data_error_result(retmsg="the model_name does not exist.") | |||
| else: | |||
| req["llm_id"] = tenant.llm_id | |||
| if not req.get("name"): | |||
| return get_data_error_result(retmsg="name is required.") | |||
| if DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||
| @@ -149,14 +153,20 @@ def save(tenant_id): | |||
| if not DialogService.query(tenant_id=tenant_id, id=req["id"], status=StatusEnum.VALID.value): | |||
| return get_json_result(data=False, retmsg='You do not own the assistant', retcode=RetCode.OPERATING_ERROR) | |||
| # prompt | |||
| if not req["id"]: | |||
| return get_data_error_result(retmsg="id can not be empty") | |||
| e, res = DialogService.get_by_id(req["id"]) | |||
| res = res.to_json() | |||
| if "llm_id" in req: | |||
| if not TenantLLMService.query(llm_name=req["llm_id"]): | |||
| return get_data_error_result(retmsg="the model_name does not exist.") | |||
| if "name" in req: | |||
| if not req.get("name"): | |||
| return get_data_error_result(retmsg="name is not empty.") | |||
| if req["name"].lower() != res["name"].lower() \ | |||
| and len(DialogService.query(name=req["name"], tenant_id=tenant_id,status=StatusEnum.VALID.value)) > 0: | |||
| return get_data_error_result(retmsg="Duplicated knowledgebase name in updating dataset.") | |||
| and len( | |||
| DialogService.query(name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value)) > 0: | |||
| return get_data_error_result(retmsg="Duplicated assistant name in updating dataset.") | |||
| if "prompt_config" in req: | |||
| res["prompt_config"].update(req["prompt_config"]) | |||
| for p in res["prompt_config"]["parameters"]: | |||
| @@ -186,7 +196,7 @@ def delete(tenant_id): | |||
| if "id" not in req: | |||
| return get_data_error_result(retmsg="id is required") | |||
| id = req['id'] | |||
| if not DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value): | |||
| if not DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value): | |||
| return get_json_result(data=False, retmsg='you do not own the assistant.', retcode=RetCode.OPERATING_ERROR) | |||
| temp_dict = {"status": StatusEnum.INVALID.value} | |||
| @@ -200,21 +210,22 @@ def get(tenant_id): | |||
| req = request.args | |||
| if "id" in req: | |||
| id = req["id"] | |||
| ass = DialogService.query(tenant_id=tenant_id, id=id,status=StatusEnum.VALID.value) | |||
| ass = DialogService.query(tenant_id=tenant_id, id=id, status=StatusEnum.VALID.value) | |||
| if not ass: | |||
| return get_json_result(data=False, retmsg='You do not own the assistant.', retcode=RetCode.OPERATING_ERROR) | |||
| if "name" in req: | |||
| name = req["name"] | |||
| if ass[0].name != name: | |||
| return get_json_result(data=False, retmsg='name does not match id.', retcode=RetCode.OPERATING_ERROR) | |||
| res=ass[0].to_json() | |||
| res = ass[0].to_json() | |||
| else: | |||
| if "name" in req: | |||
| name = req["name"] | |||
| ass = DialogService.query(name=name, tenant_id=tenant_id,status=StatusEnum.VALID.value) | |||
| ass = DialogService.query(name=name, tenant_id=tenant_id, status=StatusEnum.VALID.value) | |||
| if not ass: | |||
| return get_json_result(data=False, retmsg='You do not own the dataset.',retcode=RetCode.OPERATING_ERROR) | |||
| res=ass[0].to_json() | |||
| return get_json_result(data=False, retmsg='You do not own the assistant.', | |||
| retcode=RetCode.OPERATING_ERROR) | |||
| res = ass[0].to_json() | |||
| else: | |||
| return get_data_error_result(retmsg="At least one of `id` or `name` must be provided.") | |||
| renamed_dict = {} | |||
| @@ -258,7 +269,7 @@ def list_assistants(tenant_id): | |||
| reverse=True, | |||
| order_by=DialogService.model.create_time) | |||
| assts = [d.to_dict() for d in assts] | |||
| list_assts=[] | |||
| list_assts = [] | |||
| renamed_dict = {} | |||
| key_mapping = {"parameters": "variables", | |||
| "prologue": "opener", | |||
| @@ -60,7 +60,7 @@ def save(tenant_id): | |||
| req.update(mapped_keys) | |||
| if not KnowledgebaseService.save(**req): | |||
| return get_data_error_result(retmsg="Create dataset error.(Database error)") | |||
| renamed_data={} | |||
| renamed_data = {} | |||
| e, k = KnowledgebaseService.get_by_id(req["id"]) | |||
| for key, value in k.to_dict().items(): | |||
| new_key = key_mapping.get(key, key) | |||
| @@ -88,6 +88,9 @@ def save(tenant_id): | |||
| data=False, retmsg='You do not own the dataset.', | |||
| retcode=RetCode.OPERATING_ERROR) | |||
| if not req["id"]: | |||
| return get_data_error_result( | |||
| retmsg="id can not be empty.") | |||
| e, kb = KnowledgebaseService.get_by_id(req["id"]) | |||
| if "chunk_count" in req: | |||
| @@ -108,6 +111,7 @@ def save(tenant_id): | |||
| retmsg="If chunk count is not 0, parse method is not changable.") | |||
| req['parser_id'] = req.pop('parse_method') | |||
| if "name" in req: | |||
| req["name"] = req["name"].strip() | |||
| if req["name"].lower() != kb.name.lower() \ | |||
| and len(KnowledgebaseService.query(name=req["name"], tenant_id=tenant_id, | |||
| status=StatusEnum.VALID.value)) > 0: | |||
| @@ -0,0 +1,168 @@ | |||
| # | |||
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import json | |||
| from copy import deepcopy | |||
| from uuid import uuid4 | |||
| from flask import request, Response | |||
| from api.db import StatusEnum | |||
| from api.db.services.dialog_service import DialogService, ConversationService, chat | |||
| from api.utils import get_uuid | |||
| from api.utils.api_utils import get_data_error_result | |||
| from api.utils.api_utils import get_json_result, token_required | |||
| @manager.route('/save', methods=['POST']) | |||
| @token_required | |||
| def set_conversation(tenant_id): | |||
| req = request.json | |||
| conv_id = req.get("id") | |||
| if "messages" in req: | |||
| req["message"] = req.pop("messages") | |||
| if req["message"]: | |||
| for message in req["message"]: | |||
| if "reference" in message: | |||
| req["reference"] = message.pop("reference") | |||
| if "assistant_id" in req: | |||
| req["dialog_id"] = req.pop("assistant_id") | |||
| if "id" in req: | |||
| del req["id"] | |||
| conv = ConversationService.query(id=conv_id) | |||
| if not conv: | |||
| return get_data_error_result(retmsg="Session does not exist") | |||
| if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||
| return get_data_error_result(retmsg="You do not own the session") | |||
| if req.get("dialog_id"): | |||
| dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) | |||
| if not dia: | |||
| return get_data_error_result(retmsg="You do not own the assistant") | |||
| if "dialog_id" in req and not req.get("dialog_id"): | |||
| return get_data_error_result(retmsg="assistant_id can not be empty.") | |||
| if "name" in req and not req.get("name"): | |||
| return get_data_error_result(retmsg="name can not be empty.") | |||
| if "message" in req and not req.get("message"): | |||
| return get_data_error_result(retmsg="messages can not be empty") | |||
| if not ConversationService.update_by_id(conv_id, req): | |||
| return get_data_error_result(retmsg="Session updates error") | |||
| return get_json_result(data=True) | |||
| if not req.get("dialog_id"): | |||
| return get_data_error_result(retmsg="assistant_id is required.") | |||
| dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value) | |||
| if not dia: | |||
| return get_data_error_result(retmsg="You do not own the assistant") | |||
| conv = { | |||
| "id": get_uuid(), | |||
| "dialog_id": req["dialog_id"], | |||
| "name": req.get("name", "New session"), | |||
| "message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]), | |||
| "reference": req.get("reference", []) | |||
| } | |||
| if not conv.get("name"): | |||
| return get_data_error_result(retmsg="name can not be empty.") | |||
| if not conv.get("message"): | |||
| return get_data_error_result(retmsg="messages can not be empty") | |||
| ConversationService.save(**conv) | |||
| e, conv = ConversationService.get_by_id(conv["id"]) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Fail to new session!") | |||
| conv = conv.to_dict() | |||
| conv["messages"] = conv.pop("message") | |||
| conv["assistant_id"] = conv.pop("dialog_id") | |||
| for message in conv["messages"]: | |||
| message["reference"] = conv.get("reference") | |||
| del conv["reference"] | |||
| return get_json_result(data=conv) | |||
| @manager.route('/completion', methods=['POST']) | |||
| @token_required | |||
| def completion(tenant_id): | |||
| req = request.json | |||
| # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ | |||
| # {"role": "user", "content": "上海有吗?"} | |||
| # ]} | |||
| msg = [] | |||
| question = { | |||
| "content": req.get("question"), | |||
| "role": "user", | |||
| "id": str(uuid4()) | |||
| } | |||
| req["messages"].append(question) | |||
| for m in req["messages"]: | |||
| if m["role"] == "system": continue | |||
| if m["role"] == "assistant" and not msg: continue | |||
| m["id"] = m.get("id", str(uuid4())) | |||
| msg.append(m) | |||
| message_id = msg[-1].get("id") | |||
| conv = ConversationService.query(id=req["id"]) | |||
| conv = conv[0] | |||
| if not conv: | |||
| return get_data_error_result(retmsg="Session does not exist") | |||
| if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||
| return get_data_error_result(retmsg="You do not own the session") | |||
| conv.message = deepcopy(req["messages"]) | |||
| e, dia = DialogService.get_by_id(conv.dialog_id) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Dialog not found!") | |||
| del req["id"] | |||
| del req["messages"] | |||
| if not conv.reference: | |||
| conv.reference = [] | |||
| conv.message.append({"role": "assistant", "content": "", "id": message_id}) | |||
| conv.reference.append({"chunks": [], "doc_aggs": []}) | |||
| def fillin_conv(ans): | |||
| nonlocal conv, message_id | |||
| if not conv.reference: | |||
| conv.reference.append(ans["reference"]) | |||
| else: | |||
| conv.reference[-1] = ans["reference"] | |||
| conv.message[-1] = {"role": "assistant", "content": ans["answer"], | |||
| "id": message_id, "prompt": ans.get("prompt", "")} | |||
| ans["id"] = message_id | |||
| def stream(): | |||
| nonlocal dia, msg, req, conv | |||
| try: | |||
| for ans in chat(dia, msg, **req): | |||
| fillin_conv(ans) | |||
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" | |||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | |||
| except Exception as e: | |||
| yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e), | |||
| "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, | |||
| ensure_ascii=False) + "\n\n" | |||
| yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n" | |||
| if req.get("stream", True): | |||
| resp = Response(stream(), mimetype="text/event-stream") | |||
| resp.headers.add_header("Cache-control", "no-cache") | |||
| resp.headers.add_header("Connection", "keep-alive") | |||
| resp.headers.add_header("X-Accel-Buffering", "no") | |||
| resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") | |||
| return resp | |||
| else: | |||
| answer = None | |||
| for ans in chat(dia, msg, **req): | |||
| answer = ans | |||
| fillin_conv(ans) | |||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | |||
| break | |||
| return get_json_result(data=answer) | |||
| @@ -1,9 +1,12 @@ | |||
| from typing import List | |||
| from .base import Base | |||
| from .session import Session, Message | |||
| class Assistant(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.id="" | |||
| self.id = "" | |||
| self.name = "assistant" | |||
| self.avatar = "path/to/avatar" | |||
| self.knowledgebases = ["kb1"] | |||
| @@ -41,8 +44,8 @@ class Assistant(Base): | |||
| def save(self) -> bool: | |||
| res = self.post('/assistant/save', | |||
| {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases":self.knowledgebases, | |||
| "llm":self.llm.to_json(),"prompt":self.prompt.to_json() | |||
| {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases, | |||
| "llm": self.llm.to_json(), "prompt": self.prompt.to_json() | |||
| }) | |||
| res = res.json() | |||
| if res.get("retmsg") == "success": return True | |||
| @@ -54,3 +57,15 @@ class Assistant(Base): | |||
| res = res.json() | |||
| if res.get("retmsg") == "success": return True | |||
| raise Exception(res["retmsg"]) | |||
| def create_session(self, name: str = "New session", messages: List[Message] = [ | |||
| {"role": "assistant", "reference": [], | |||
| "content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session: | |||
| res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, }) | |||
| res = res.json() | |||
| if res.get("retmsg") == "success": | |||
| return Session(self.rag, res['data']) | |||
| raise Exception(res["retmsg"]) | |||
| def get_prologue(self): | |||
| return self.prompt.opener | |||
| @@ -0,0 +1,64 @@ | |||
| import json | |||
| from .base import Base | |||
| class Session(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.id = None | |||
| self.name = "New session" | |||
| self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | |||
| self.assistant_id = None | |||
| super().__init__(rag, res_dict) | |||
| def chat(self, question: str, stream: bool = False): | |||
| res = self.post("/session/completion", | |||
| {"id": self.id, "question": question, "stream": stream, "messages": self.messages}) | |||
| res = res.text | |||
| response_lines = res.splitlines() | |||
| message_list = [] | |||
| for line in response_lines: | |||
| if line.startswith("data:"): | |||
| json_data = json.loads(line[5:]) | |||
| if json_data["data"] != True: | |||
| answer = json_data["data"]["answer"] | |||
| reference = json_data["data"]["reference"] | |||
| temp_dict = { | |||
| "content": answer, | |||
| "role": "assistant", | |||
| "reference": reference | |||
| } | |||
| message = Message(self.rag, temp_dict) | |||
| message_list.append(message) | |||
| return message_list | |||
| def save(self): | |||
| res = self.post("/session/save", | |||
| {"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages}) | |||
| res = res.json() | |||
| if res.get("retmsg") == "success": return True | |||
| raise Exception(res.get("retmsg")) | |||
| class Message(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.content = "您好,我是您的助手小樱,长得可爱又善良,can I help you?" | |||
| self.reference = [] | |||
| self.role = "assistant" | |||
| self.prompt=None | |||
| super().__init__(rag, res_dict) | |||
| class Chunk(Base): | |||
| def __init__(self, rag, res_dict): | |||
| self.id = None | |||
| self.content = None | |||
| self.document_id = None | |||
| self.document_name = None | |||
| self.knowledgebase_id = None | |||
| self.image_id = None | |||
| self.similarity = None | |||
| self.vector_similarity = None | |||
| self.term_similarity = None | |||
| self.positions = None | |||
| super().__init__(rag, res_dict) | |||
| @@ -17,7 +17,6 @@ from typing import List | |||
| import requests | |||
| from .modules.chat_assistant import Assistant | |||
| from .modules.dataset import DataSet | |||
| @@ -88,7 +87,7 @@ class RAGFlow: | |||
| datasets.append(dataset.to_json()) | |||
| if llm is None: | |||
| llm = Assistant.LLM(self, {"model_name": "deepseek-chat", | |||
| llm = Assistant.LLM(self, {"model_name": None, | |||
| "temperature": 0.1, | |||
| "top_p": 0.3, | |||
| "presence_penalty": 0.4, | |||
| @@ -142,4 +141,4 @@ class RAGFlow: | |||
| for data in res['data']: | |||
| result_list.append(Assistant(self, data)) | |||
| return result_list | |||
| raise Exception(res["retmsg"]) | |||
| raise Exception(res["retmsg"]) | |||
| @@ -10,10 +10,10 @@ class TestAssistant(TestSdk): | |||
| Test creating an assistant with success | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.get_dataset(name="God") | |||
| assistant = rag.create_assistant("God",knowledgebases=[kb]) | |||
| kb = rag.create_dataset(name="test_create_assistant") | |||
| assistant = rag.create_assistant("test_create", knowledgebases=[kb]) | |||
| if isinstance(assistant, Assistant): | |||
| assert assistant.name == "God", "Name does not match." | |||
| assert assistant.name == "test_create", "Name does not match." | |||
| else: | |||
| assert False, f"Failed to create assistant, error: {assistant}" | |||
| @@ -22,11 +22,11 @@ class TestAssistant(TestSdk): | |||
| Test updating an assistant with success. | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.get_dataset(name="God") | |||
| assistant = rag.create_assistant("ABC",knowledgebases=[kb]) | |||
| kb = rag.create_dataset(name="test_update_assistant") | |||
| assistant = rag.create_assistant("test_update", knowledgebases=[kb]) | |||
| if isinstance(assistant, Assistant): | |||
| assert assistant.name == "ABC", "Name does not match." | |||
| assistant.name = 'DEF' | |||
| assert assistant.name == "test_update", "Name does not match." | |||
| assistant.name = 'new_assistant' | |||
| res = assistant.save() | |||
| assert res is True, f"Failed to update assistant, error: {res}" | |||
| else: | |||
| @@ -37,10 +37,10 @@ class TestAssistant(TestSdk): | |||
| Test deleting an assistant with success | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.get_dataset(name="God") | |||
| assistant = rag.create_assistant("MA",knowledgebases=[kb]) | |||
| kb = rag.create_dataset(name="test_delete_assistant") | |||
| assistant = rag.create_assistant("test_delete", knowledgebases=[kb]) | |||
| if isinstance(assistant, Assistant): | |||
| assert assistant.name == "MA", "Name does not match." | |||
| assert assistant.name == "test_delete", "Name does not match." | |||
| res = assistant.delete() | |||
| assert res is True, f"Failed to delete assistant, error: {res}" | |||
| else: | |||
| @@ -61,6 +61,8 @@ class TestAssistant(TestSdk): | |||
| Test getting an assistant's detail with success | |||
| """ | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| assistant = rag.get_assistant(name="God") | |||
| kb = rag.create_dataset(name="test_get_assistant") | |||
| rag.create_assistant("test_get_assistant", knowledgebases=[kb]) | |||
| assistant = rag.get_assistant(name="test_get_assistant") | |||
| assert isinstance(assistant, Assistant), f"Failed to get assistant, error: {assistant}." | |||
| assert assistant.name == "God", "Name does not match" | |||
| assert assistant.name == "test_get_assistant", "Name does not match" | |||
| @@ -0,0 +1,27 @@ | |||
| from ragflow import RAGFlow | |||
| from common import API_KEY, HOST_ADDRESS | |||
| class TestChatSession: | |||
| def test_create_session(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_create_session") | |||
| assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb]) | |||
| session = assistant.create_session() | |||
| assert assistant is not None, "Failed to get the assistant." | |||
| assert session is not None, "Failed to create a session." | |||
| def test_create_chat_with_success(self): | |||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||
| kb = rag.create_dataset(name="test_create_chat") | |||
| assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb]) | |||
| session = assistant.create_session() | |||
| assert session is not None, "Failed to create a session." | |||
| prologue = assistant.get_prologue() | |||
| assert isinstance(prologue, str), "Prologue is not a string." | |||
| assert len(prologue) > 0, "Prologue is empty." | |||
| question = "What is AI" | |||
| ans = session.chat(question, stream=True) | |||
| response = ans[-1].content | |||
| assert len(response) > 0, "Assistant did not return any response." | |||