### What problem does this PR solve? Includes SDK for creating, updating sessions, getting sessions, listing sessions, and dialogues #1102 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>tags/v0.11.0
| # limitations under the License. | # limitations under the License. | ||||
| # | # | ||||
| import json | import json | ||||
| from copy import deepcopy | |||||
| from uuid import uuid4 | from uuid import uuid4 | ||||
| from flask import request, Response | from flask import request, Response | ||||
| from api.db import StatusEnum | from api.db import StatusEnum | ||||
| from api.db.services.dialog_service import DialogService, ConversationService, chat | from api.db.services.dialog_service import DialogService, ConversationService, chat | ||||
| from api.settings import RetCode | |||||
| from api.utils import get_uuid | from api.utils import get_uuid | ||||
| from api.utils.api_utils import get_data_error_result | from api.utils.api_utils import get_data_error_result | ||||
| from api.utils.api_utils import get_json_result, token_required | from api.utils.api_utils import get_json_result, token_required | ||||
| def set_conversation(tenant_id): | def set_conversation(tenant_id): | ||||
| req = request.json | req = request.json | ||||
| conv_id = req.get("id") | conv_id = req.get("id") | ||||
| if "messages" in req: | |||||
| req["message"] = req.pop("messages") | |||||
| if req["message"]: | |||||
| for message in req["message"]: | |||||
| if "reference" in message: | |||||
| req["reference"] = message.pop("reference") | |||||
| if "assistant_id" in req: | if "assistant_id" in req: | ||||
| req["dialog_id"] = req.pop("assistant_id") | req["dialog_id"] = req.pop("assistant_id") | ||||
| if "id" in req: | if "id" in req: | ||||
| return get_data_error_result(retmsg="You do not own the assistant") | return get_data_error_result(retmsg="You do not own the assistant") | ||||
| if "dialog_id" in req and not req.get("dialog_id"): | if "dialog_id" in req and not req.get("dialog_id"): | ||||
| return get_data_error_result(retmsg="assistant_id can not be empty.") | return get_data_error_result(retmsg="assistant_id can not be empty.") | ||||
| if "message" in req: | |||||
| return get_data_error_result(retmsg="message can not be change") | |||||
| if "reference" in req: | |||||
| return get_data_error_result(retmsg="reference can not be change") | |||||
| if "name" in req and not req.get("name"): | if "name" in req and not req.get("name"): | ||||
| return get_data_error_result(retmsg="name can not be empty.") | return get_data_error_result(retmsg="name can not be empty.") | ||||
| if "message" in req and not req.get("message"): | |||||
| return get_data_error_result(retmsg="messages can not be empty") | |||||
| if not ConversationService.update_by_id(conv_id, req): | if not ConversationService.update_by_id(conv_id, req): | ||||
| return get_data_error_result(retmsg="Session updates error") | return get_data_error_result(retmsg="Session updates error") | ||||
| return get_json_result(data=True) | return get_json_result(data=True) | ||||
| "id": get_uuid(), | "id": get_uuid(), | ||||
| "dialog_id": req["dialog_id"], | "dialog_id": req["dialog_id"], | ||||
| "name": req.get("name", "New session"), | "name": req.get("name", "New session"), | ||||
| "message": req.get("message", [{"role": "assistant", "content": dia[0].prompt_config["prologue"]}]), | |||||
| "reference": req.get("reference", []) | |||||
| "message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | |||||
| } | } | ||||
| if not conv.get("name"): | if not conv.get("name"): | ||||
| return get_data_error_result(retmsg="name can not be empty.") | return get_data_error_result(retmsg="name can not be empty.") | ||||
| if not conv.get("message"): | |||||
| return get_data_error_result(retmsg="messages can not be empty") | |||||
| ConversationService.save(**conv) | ConversationService.save(**conv) | ||||
| e, conv = ConversationService.get_by_id(conv["id"]) | e, conv = ConversationService.get_by_id(conv["id"]) | ||||
| if not e: | if not e: | ||||
| return get_data_error_result(retmsg="Fail to new session!") | return get_data_error_result(retmsg="Fail to new session!") | ||||
| conv = conv.to_dict() | conv = conv.to_dict() | ||||
| conv["messages"] = conv.pop("message") | |||||
| conv['messages'] = conv.pop("message") | |||||
| conv["assistant_id"] = conv.pop("dialog_id") | conv["assistant_id"] = conv.pop("dialog_id") | ||||
| for message in conv["messages"]: | |||||
| message["reference"] = conv.get("reference") | |||||
| del conv["reference"] | del conv["reference"] | ||||
| return get_json_result(data=conv) | return get_json_result(data=conv) | ||||
| # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ | # req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [ | ||||
| # {"role": "user", "content": "上海有吗?"} | # {"role": "user", "content": "上海有吗?"} | ||||
| # ]} | # ]} | ||||
| if "id" not in req: | |||||
| return get_data_error_result(retmsg="id is required") | |||||
| conv = ConversationService.query(id=req["id"]) | |||||
| if not conv: | |||||
| return get_data_error_result(retmsg="Session does not exist") | |||||
| conv = conv[0] | |||||
| if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_data_error_result(retmsg="You do not own the session") | |||||
| msg = [] | msg = [] | ||||
| question = { | question = { | ||||
| "content": req.get("question"), | "content": req.get("question"), | ||||
| "role": "user", | "role": "user", | ||||
| "id": str(uuid4()) | "id": str(uuid4()) | ||||
| } | } | ||||
| req["messages"].append(question) | |||||
| for m in req["messages"]: | |||||
| conv.message.append(question) | |||||
| for m in conv.message: | |||||
| if m["role"] == "system": continue | if m["role"] == "system": continue | ||||
| if m["role"] == "assistant" and not msg: continue | if m["role"] == "assistant" and not msg: continue | ||||
| m["id"] = m.get("id", str(uuid4())) | |||||
| msg.append(m) | msg.append(m) | ||||
| message_id = msg[-1].get("id") | message_id = msg[-1].get("id") | ||||
| conv = ConversationService.query(id=req["id"]) | |||||
| conv = conv[0] | |||||
| if not conv: | |||||
| return get_data_error_result(retmsg="Session does not exist") | |||||
| if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_data_error_result(retmsg="You do not own the session") | |||||
| conv.message = deepcopy(req["messages"]) | |||||
| e, dia = DialogService.get_by_id(conv.dialog_id) | e, dia = DialogService.get_by_id(conv.dialog_id) | ||||
| if not e: | |||||
| return get_data_error_result(retmsg="Dialog not found!") | |||||
| del req["id"] | del req["id"] | ||||
| del req["messages"] | |||||
| if not conv.reference: | if not conv.reference: | ||||
| conv.reference = [] | conv.reference = [] | ||||
| ConversationService.update_by_id(conv.id, conv.to_dict()) | ConversationService.update_by_id(conv.id, conv.to_dict()) | ||||
| break | break | ||||
| return get_json_result(data=answer) | return get_json_result(data=answer) | ||||
| @manager.route('/get', methods=['GET']) | |||||
| @token_required | |||||
| def get(tenant_id): | |||||
| req = request.args | |||||
| if "id" not in req: | |||||
| return get_data_error_result(retmsg="id is required") | |||||
| conv_id = req["id"] | |||||
| conv = ConversationService.query(id=conv_id) | |||||
| if not conv: | |||||
| return get_data_error_result(retmsg="Session does not exist") | |||||
| if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_data_error_result(retmsg="You do not own the session") | |||||
| conv = conv[0].to_dict() | |||||
| conv['messages'] = conv.pop("message") | |||||
| conv["assistant_id"] = conv.pop("dialog_id") | |||||
| if conv["reference"]: | |||||
| messages = conv["messages"] | |||||
| message_num = 0 | |||||
| chunk_num = 0 | |||||
| while message_num < len(messages): | |||||
| if message_num != 0 and messages[message_num]["role"] != "user": | |||||
| chunk_list = [] | |||||
| if "chunks" in conv["reference"][chunk_num]: | |||||
| chunks = conv["reference"][chunk_num]["chunks"] | |||||
| for chunk in chunks: | |||||
| new_chunk = { | |||||
| "id": chunk["chunk_id"], | |||||
| "content": chunk["content_with_weight"], | |||||
| "document_id": chunk["doc_id"], | |||||
| "document_name": chunk["docnm_kwd"], | |||||
| "knowledgebase_id": chunk["kb_id"], | |||||
| "image_id": chunk["img_id"], | |||||
| "similarity": chunk["similarity"], | |||||
| "vector_similarity": chunk["vector_similarity"], | |||||
| "term_similarity": chunk["term_similarity"], | |||||
| "positions": chunk["positions"], | |||||
| } | |||||
| chunk_list.append(new_chunk) | |||||
| chunk_num += 1 | |||||
| messages[message_num]["reference"] = chunk_list | |||||
| message_num += 1 | |||||
| del conv["reference"] | |||||
| return get_json_result(data=conv) | |||||
| @manager.route('/list', methods=["GET"]) | |||||
| @token_required | |||||
| def list(tenant_id): | |||||
| assistant_id = request.args["assistant_id"] | |||||
| if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value): | |||||
| return get_json_result( | |||||
| data=False, retmsg=f'Only owner of the assistant is authorized for this operation.', | |||||
| retcode=RetCode.OPERATING_ERROR) | |||||
| convs = ConversationService.query( | |||||
| dialog_id=assistant_id, | |||||
| order_by=ConversationService.model.create_time, | |||||
| reverse=True) | |||||
| convs = [d.to_dict() for d in convs] | |||||
| for conv in convs: | |||||
| conv['messages'] = conv.pop("message") | |||||
| conv["assistant_id"] = conv.pop("dialog_id") | |||||
| if conv["reference"]: | |||||
| messages = conv["messages"] | |||||
| message_num = 0 | |||||
| chunk_num = 0 | |||||
| while message_num < len(messages): | |||||
| if message_num != 0 and messages[message_num]["role"] != "user": | |||||
| chunk_list = [] | |||||
| if "chunks" in conv["reference"][chunk_num]: | |||||
| chunks = conv["reference"][chunk_num]["chunks"] | |||||
| for chunk in chunks: | |||||
| new_chunk = { | |||||
| "id": chunk["chunk_id"], | |||||
| "content": chunk["content_with_weight"], | |||||
| "document_id": chunk["doc_id"], | |||||
| "document_name": chunk["docnm_kwd"], | |||||
| "knowledgebase_id": chunk["kb_id"], | |||||
| "image_id": chunk["img_id"], | |||||
| "similarity": chunk["similarity"], | |||||
| "vector_similarity": chunk["vector_similarity"], | |||||
| "term_similarity": chunk["term_similarity"], | |||||
| "positions": chunk["positions"], | |||||
| } | |||||
| chunk_list.append(new_chunk) | |||||
| chunk_num += 1 | |||||
| messages[message_num]["reference"] = chunk_list | |||||
| message_num += 1 | |||||
| del conv["reference"] | |||||
| return get_json_result(data=convs) | |||||
| @manager.route('/delete', methods=["DELETE"]) | |||||
| @token_required | |||||
| def delete(tenant_id): | |||||
| id = request.args.get("id") | |||||
| if not id: | |||||
| return get_data_error_result(retmsg="`id` is required in deleting operation") | |||||
| conv = ConversationService.query(id=id) | |||||
| if not conv: | |||||
| return get_data_error_result(retmsg="Session doesn't exist") | |||||
| conv = conv[0] | |||||
| if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value): | |||||
| return get_data_error_result(retmsg="You don't own the session") | |||||
| ConversationService.delete_by_id(id) | |||||
| return get_json_result(data=True) |
| from .ragflow import RAGFlow | from .ragflow import RAGFlow | ||||
| from .modules.dataset import DataSet | from .modules.dataset import DataSet | ||||
| from .modules.chat_assistant import Assistant | |||||
| from .modules.assistant import Assistant | |||||
| from .modules.session import Session |
| from typing import List | |||||
| from .base import Base | |||||
| from .session import Session, Message | |||||
| class Assistant(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.id = "" | |||||
| self.name = "assistant" | |||||
| self.avatar = "path/to/avatar" | |||||
| self.knowledgebases = ["kb1"] | |||||
| self.llm = Assistant.LLM(rag, {}) | |||||
| self.prompt = Assistant.Prompt(rag, {}) | |||||
| super().__init__(rag, res_dict) | |||||
| class LLM(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.model_name = "deepseek-chat" | |||||
| self.temperature = 0.1 | |||||
| self.top_p = 0.3 | |||||
| self.presence_penalty = 0.4 | |||||
| self.frequency_penalty = 0.7 | |||||
| self.max_tokens = 512 | |||||
| super().__init__(rag, res_dict) | |||||
| class Prompt(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.similarity_threshold = 0.2 | |||||
| self.keywords_similarity_weight = 0.7 | |||||
| self.top_n = 8 | |||||
| self.variables = [{"key": "knowledge", "optional": True}] | |||||
| self.rerank_model = None | |||||
| self.empty_response = None | |||||
| self.opener = "Hi! I'm your assistant, what can I do for you?" | |||||
| self.show_quote = True | |||||
| self.prompt = ( | |||||
| "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. " | |||||
| "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, " | |||||
| "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' " | |||||
| "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base." | |||||
| ) | |||||
| super().__init__(rag, res_dict) | |||||
| def save(self) -> bool: | |||||
| res = self.post('/assistant/save', | |||||
| {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases, | |||||
| "llm": self.llm.to_json(), "prompt": self.prompt.to_json() | |||||
| }) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": return True | |||||
| raise Exception(res["retmsg"]) | |||||
| def delete(self) -> bool: | |||||
| res = self.rm('/assistant/delete', | |||||
| {"id": self.id}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": return True | |||||
| raise Exception(res["retmsg"]) | |||||
| def create_session(self, name: str = "New session", messages: List[Message] = [ | |||||
| {"role": "assistant", "reference": [], | |||||
| "content": "您好,我是您的助手小樱,长得可爱又善良,can I help you?"}]) -> Session: | |||||
| res = self.post("/session/save", {"name": name, "messages": messages, "assistant_id": self.id, }) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": | |||||
| return Session(self.rag, res['data']) | |||||
| raise Exception(res["retmsg"]) | |||||
| def get_prologue(self): | |||||
| return self.prompt.opener | |||||
| from typing import List | |||||
| from .base import Base | |||||
| from .session import Session | |||||
| class Assistant(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.id = "" | |||||
| self.name = "assistant" | |||||
| self.avatar = "path/to/avatar" | |||||
| self.knowledgebases = ["kb1"] | |||||
| self.llm = Assistant.LLM(rag, {}) | |||||
| self.prompt = Assistant.Prompt(rag, {}) | |||||
| super().__init__(rag, res_dict) | |||||
| class LLM(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.model_name = "deepseek-chat" | |||||
| self.temperature = 0.1 | |||||
| self.top_p = 0.3 | |||||
| self.presence_penalty = 0.4 | |||||
| self.frequency_penalty = 0.7 | |||||
| self.max_tokens = 512 | |||||
| super().__init__(rag, res_dict) | |||||
| class Prompt(Base): | |||||
| def __init__(self, rag, res_dict): | |||||
| self.similarity_threshold = 0.2 | |||||
| self.keywords_similarity_weight = 0.7 | |||||
| self.top_n = 8 | |||||
| self.variables = [{"key": "knowledge", "optional": True}] | |||||
| self.rerank_model = None | |||||
| self.empty_response = None | |||||
| self.opener = "Hi! I'm your assistant, what can I do for you?" | |||||
| self.show_quote = True | |||||
| self.prompt = ( | |||||
| "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. " | |||||
| "Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, " | |||||
| "your answer must include the sentence 'The answer you are looking for is not found in the knowledge base!' " | |||||
| "Answers need to consider chat history.\nHere is the knowledge base:\n{knowledge}\nThe above is the knowledge base." | |||||
| ) | |||||
| super().__init__(rag, res_dict) | |||||
| def save(self) -> bool: | |||||
| res = self.post('/assistant/save', | |||||
| {"id": self.id, "name": self.name, "avatar": self.avatar, "knowledgebases": self.knowledgebases, | |||||
| "llm": self.llm.to_json(), "prompt": self.prompt.to_json() | |||||
| }) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": return True | |||||
| raise Exception(res["retmsg"]) | |||||
| def delete(self) -> bool: | |||||
| res = self.rm('/assistant/delete', | |||||
| {"id": self.id}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": return True | |||||
| raise Exception(res["retmsg"]) | |||||
| def create_session(self, name: str = "New session") -> Session: | |||||
| res = self.post("/session/save", {"name": name, "assistant_id": self.id}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": | |||||
| return Session(self.rag, res['data']) | |||||
| raise Exception(res["retmsg"]) | |||||
| def list_session(self) -> List[Session]: | |||||
| res = self.get('/session/list', {"assistant_id": self.id}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": | |||||
| result_list = [] | |||||
| for data in res["data"]: | |||||
| result_list.append(Session(self.rag, data)) | |||||
| return result_list | |||||
| raise Exception(res["retmsg"]) | |||||
| def get_session(self, id) -> Session: | |||||
| res = self.get("/session/get", {"id": id}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": | |||||
| return Session(self.rag, res["data"]) | |||||
| raise Exception(res["retmsg"]) | |||||
| def get_prologue(self): | |||||
| return self.prompt.opener |
| pr[name] = value | pr[name] = value | ||||
| return pr | return pr | ||||
| def post(self, path, param): | |||||
| res = self.rag.post(path, param) | |||||
| def post(self, path, param, stream=False): | |||||
| res = self.rag.post(path, param, stream=stream) | |||||
| return res | return res | ||||
| def get(self, path, params): | def get(self, path, params): |
| self.id = None | self.id = None | ||||
| self.name = "New session" | self.name = "New session" | ||||
| self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | self.messages = [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}] | ||||
| self.assistant_id = None | self.assistant_id = None | ||||
| super().__init__(rag, res_dict) | super().__init__(rag, res_dict) | ||||
| def chat(self, question: str, stream: bool = False): | def chat(self, question: str, stream: bool = False): | ||||
| for message in self.messages: | |||||
| if "reference" in message: | |||||
| message.pop("reference") | |||||
| res = self.post("/session/completion", | res = self.post("/session/completion", | ||||
| {"id": self.id, "question": question, "stream": stream, "messages": self.messages}) | |||||
| res = res.text | |||||
| response_lines = res.splitlines() | |||||
| message_list = [] | |||||
| for line in response_lines: | |||||
| {"id": self.id, "question": question, "stream": stream}, stream=True) | |||||
| for line in res.iter_lines(): | |||||
| line = line.decode("utf-8") | |||||
| if line.startswith("data:"): | if line.startswith("data:"): | ||||
| json_data = json.loads(line[5:]) | json_data = json.loads(line[5:]) | ||||
| if json_data["data"] != True: | if json_data["data"] != True: | ||||
| reference = json_data["data"]["reference"] | reference = json_data["data"]["reference"] | ||||
| temp_dict = { | temp_dict = { | ||||
| "content": answer, | "content": answer, | ||||
| "role": "assistant", | |||||
| "reference": reference | |||||
| "role": "assistant" | |||||
| } | } | ||||
| if "chunks" in reference: | |||||
| chunks = reference["chunks"] | |||||
| chunk_list = [] | |||||
| for chunk in chunks: | |||||
| new_chunk = { | |||||
| "id": chunk["chunk_id"], | |||||
| "content": chunk["content_with_weight"], | |||||
| "document_id": chunk["doc_id"], | |||||
| "document_name": chunk["docnm_kwd"], | |||||
| "knowledgebase_id": chunk["kb_id"], | |||||
| "image_id": chunk["img_id"], | |||||
| "similarity": chunk["similarity"], | |||||
| "vector_similarity": chunk["vector_similarity"], | |||||
| "term_similarity": chunk["term_similarity"], | |||||
| "positions": chunk["positions"], | |||||
| } | |||||
| chunk_list.append(new_chunk) | |||||
| temp_dict["reference"] = chunk_list | |||||
| message = Message(self.rag, temp_dict) | message = Message(self.rag, temp_dict) | ||||
| message_list.append(message) | |||||
| return message_list | |||||
| yield message | |||||
| def save(self): | def save(self): | ||||
| res = self.post("/session/save", | res = self.post("/session/save", | ||||
| {"id": self.id, "dialog_id": self.assistant_id, "name": self.name, "message": self.messages}) | |||||
| {"id": self.id, "assistant_id": self.assistant_id, "name": self.name}) | |||||
| res = res.json() | |||||
| if res.get("retmsg") == "success": return True | |||||
| raise Exception(res.get("retmsg")) | |||||
| def delete(self): | |||||
| res = self.rm("/session/delete", {"id": self.id}) | |||||
| res = res.json() | res = res.json() | ||||
| if res.get("retmsg") == "success": return True | if res.get("retmsg") == "success": return True | ||||
| raise Exception(res.get("retmsg")) | raise Exception(res.get("retmsg")) | ||||
| class Message(Base): | class Message(Base): | ||||
| def __init__(self, rag, res_dict): | def __init__(self, rag, res_dict): | ||||
| self.content = "您好,我是您的助手小樱,长得可爱又善良,can I help you?" | |||||
| self.reference = [] | |||||
| self.content = "Hi! I am your assistant,can I help you?" | |||||
| self.reference = None | |||||
| self.role = "assistant" | self.role = "assistant" | ||||
| self.prompt=None | |||||
| self.prompt = None | |||||
| super().__init__(rag, res_dict) | super().__init__(rag, res_dict) | ||||
| import requests | import requests | ||||
| from .modules.chat_assistant import Assistant | |||||
| from .modules.assistant import Assistant | |||||
| from .modules.dataset import DataSet | from .modules.dataset import DataSet | ||||
| self.api_url = f"{base_url}/api/{version}" | self.api_url = f"{base_url}/api/{version}" | ||||
| self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} | self.authorization_header = {"Authorization": "{} {}".format("Bearer", self.user_key)} | ||||
| def post(self, path, param): | |||||
| res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header) | |||||
| def post(self, path, param, stream=False): | |||||
| res = requests.post(url=self.api_url + path, json=param, headers=self.authorization_header, stream=stream) | |||||
| return res | return res | ||||
| def get(self, path, params=None): | def get(self, path, params=None): |
| from ragflow import RAGFlow | |||||
| from ragflow import RAGFlow,Session | |||||
| from common import API_KEY, HOST_ADDRESS | from common import API_KEY, HOST_ADDRESS | ||||
| class TestChatSession: | |||||
| class TestSession: | |||||
| def test_create_session(self): | def test_create_session(self): | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.create_dataset(name="test_create_session") | kb = rag.create_dataset(name="test_create_session") | ||||
| assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb]) | assistant = rag.create_assistant(name="test_create_session", knowledgebases=[kb]) | ||||
| session = assistant.create_session() | session = assistant.create_session() | ||||
| assert assistant is not None, "Failed to get the assistant." | |||||
| assert session is not None, "Failed to create a session." | |||||
| assert isinstance(session,Session), "Failed to create a session." | |||||
| def test_create_chat_with_success(self): | def test_create_chat_with_success(self): | ||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | rag = RAGFlow(API_KEY, HOST_ADDRESS) | ||||
| kb = rag.create_dataset(name="test_create_chat") | kb = rag.create_dataset(name="test_create_chat") | ||||
| assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb]) | assistant = rag.create_assistant(name="test_create_chat", knowledgebases=[kb]) | ||||
| session = assistant.create_session() | session = assistant.create_session() | ||||
| assert session is not None, "Failed to create a session." | |||||
| prologue = assistant.get_prologue() | |||||
| assert isinstance(prologue, str), "Prologue is not a string." | |||||
| assert len(prologue) > 0, "Prologue is empty." | |||||
| question = "What is AI" | question = "What is AI" | ||||
| ans = session.chat(question, stream=True) | |||||
| response = ans[-1].content | |||||
| assert len(response) > 0, "Assistant did not return any response." | |||||
| for ans in session.chat(question, stream=True): | |||||
| pass | |||||
| assert ans.content!="\n**ERROR**", "Please check this error." | |||||
| def test_delete_session_with_success(self): | |||||
| rag = RAGFlow(API_KEY, HOST_ADDRESS) | |||||
| kb = rag.create_dataset(name="test_delete_session") | |||||
| assistant = rag.create_assistant(name="test_delete_session",knowledgebases=[kb]) | |||||
| session=assistant.create_session() | |||||
| res=session.delete() | |||||
| assert res, "Failed to delete the dataset." | |||||
| def test_update_session_with_success(self): | |||||
| rag=RAGFlow(API_KEY,HOST_ADDRESS) | |||||
| kb=rag.create_dataset(name="test_update_session") | |||||
| assistant = rag.create_assistant(name="test_update_session",knowledgebases=[kb]) | |||||
| session=assistant.create_session(name="old session") | |||||
| session.name="new session" | |||||
| res=session.save() | |||||
| assert res,"Failed to update the session" | |||||
| def test_get_session_with_success(self): | |||||
| rag=RAGFlow(API_KEY,HOST_ADDRESS) | |||||
| kb=rag.create_dataset(name="test_get_session") | |||||
| assistant = rag.create_assistant(name="test_get_session",knowledgebases=[kb]) | |||||
| session = assistant.create_session() | |||||
| session_2= assistant.get_session(id=session.id) | |||||
| assert session.to_json()==session_2.to_json(),"Failed to get the session" | |||||
| def test_list_session_with_success(self): | |||||
| rag=RAGFlow(API_KEY,HOST_ADDRESS) | |||||
| kb=rag.create_dataset(name="test_list_session") | |||||
| assistant=rag.create_assistant(name="test_list_session",knowledgebases=[kb]) | |||||
| assistant.create_session("test_1") | |||||
| assistant.create_session("test_2") | |||||
| sessions=assistant.list_session() | |||||
| if isinstance(sessions,list): | |||||
| for session in sessions: | |||||
| assert isinstance(session,Session),"Non-Session elements exist in the list" | |||||
| else : | |||||
| assert False,"Failed to retrieve the session list." |