### What problem does this PR solve?

#709

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import re
from datetime import datetime, timedelta
from flask import request
from flask import request, Response
from flask_login import login_required, current_user
from api.db import FileType, ParserType
@@ -31,11 +32,11 @@ from api.settings import RetCode
from api.utils import get_uuid, current_timestamp, datetime_format
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
from itsdangerous import URLSafeTimedSerializer
from api.db.services.task_service import TaskService, queue_tasks
from api.utils.file_utils import filename_type, thumbnail
from rag.utils.minio_conn import MINIO
from api.db.db_models import Task
from api.db.services.file2document_service import File2DocumentService


def generate_confirmation_token(tenent_id):
    serializer = URLSafeTimedSerializer(tenent_id)
    return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@@ -164,6 +165,7 @@ def completion():
        e, conv = API4ConversationService.get_by_id(req["conversation_id"])
        if not e:
            return get_data_error_result(retmsg="Conversation not found!")
        if "quote" not in req: req["quote"] = False
        msg = []
        for m in req["messages"]:
@@ -180,13 +182,45 @@ def completion():
            return get_data_error_result(retmsg="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

        ans = chat(dia, msg, **req)
        if not conv.reference:
            conv.reference = []
        conv.reference.append(ans["reference"])
        conv.message.append({"role": "assistant", "content": ans["answer"]})
        API4ConversationService.append_message(conv.id, conv.to_dict())
        return get_json_result(data=ans)

        conv.message.append({"role": "assistant", "content": ""})
        conv.reference.append({"chunks": [], "doc_aggs": []})

        def fillin_conv(ans):
            nonlocal conv
            if not conv.reference:
                conv.reference.append(ans["reference"])
            else:
                conv.reference[-1] = ans["reference"]
            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}

        def stream():
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    fillin_conv(ans)
                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
                API4ConversationService.append_message(conv.id, conv.to_dict())
            except Exception as e:
                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                           ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

        if req.get("stream", True):
            resp = Response(stream(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp
        else:
            answer = None
            for ans in chat(dia, msg, False, **req):
                answer = ans
            fillin_conv(answer)
            API4ConversationService.append_message(conv.id, conv.to_dict())
            return get_json_result(data=answer)
    except Exception as e:
        return server_error_response(e)
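
For context, a minimal client sketch for consuming the new SSE stream. The host, mount point, and auth header are illustrative assumptions, not taken from this diff; the frame format (`data:` prefix, cumulative answer, final `"data": true` sentinel) comes from the `stream()` generator above:

```python
import json
import requests  # third-party HTTP client

# Assumed deployment URL and API token; adjust to your setup.
url = "http://localhost:9380/v1/api/completion"
headers = {"Authorization": "Bearer ragflow-<token>"}
payload = {
    "conversation_id": "<conversation-id>",
    "messages": [{"role": "user", "content": "What is RAGFlow?"}],
    "stream": True,  # the default in this PR
}

with requests.post(url, json=payload, headers=headers, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip blank SSE separators
        event = json.loads(line[len("data:"):])
        if event["data"] is True:
            break  # final sentinel frame, stream is done
        print(event["data"]["answer"])  # cumulative answer so far
```

Each frame carries the whole answer so far rather than a delta, so a client can simply re-render the latest frame.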
@@ -229,7 +263,6 @@ def upload():
        return get_json_result(
            data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)

    file = request.files['file']
    if file.filename == '':
        return get_json_result(
@@ -253,7 +286,6 @@ def upload():
            location += "_"
        blob = request.files['file'].read()
        MINIO.put(kb_id, location, blob)

        doc = {
            "id": get_uuid(),
            "kb_id": kb.id,
@@ -266,42 +298,11 @@ def upload():
            "size": len(blob),
            "thumbnail": thumbnail(filename, blob)
        }

        form_data = request.form
        if "parser_id" in form_data.keys():
            if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
                doc["parser_id"] = request.form.get("parser_id").strip()
        if doc["type"] == FileType.VISUAL:
            doc["parser_id"] = ParserType.PICTURE.value
        if re.search(r"\.(ppt|pptx|pages)$", filename):
            doc["parser_id"] = ParserType.PRESENTATION.value

        doc_result = DocumentService.insert(doc)
        doc = DocumentService.insert(doc)
        return get_json_result(data=doc.to_json())
    except Exception as e:
        return server_error_response(e)

    if "run" in form_data.keys():
        if request.form.get("run").strip() == "1":
            try:
                info = {"run": 1, "progress": 0}
                info["progress_msg"] = ""
                info["chunk_num"] = 0
                info["token_num"] = 0
                DocumentService.update_by_id(doc["id"], info)
                # if str(req["run"]) == TaskStatus.CANCEL.value:
                tenant_id = DocumentService.get_tenant_id(doc["id"])
                if not tenant_id:
                    return get_data_error_result(retmsg="Tenant not found!")
                # e, doc = DocumentService.get_by_id(doc["id"])
                TaskService.filter_delete([Task.doc_id == doc["id"]])
                e, doc = DocumentService.get_by_id(doc["id"])
                doc = doc.to_dict()
                doc["tenant_id"] = tenant_id
                bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
                queue_tasks(doc, bucket, name)
            except Exception as e:
                return server_error_response(e)

    return get_json_result(data=doc_result.to_json())
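
A hedged sketch of the extended upload call; the route, auth header, and the `kb_name` form field are assumptions for illustration. What the diff does establish is that `parser_id` must be one of the `ParserType` values and that `run=1` queues a parsing task right after the upload:

```python
import requests

resp = requests.post(
    "http://localhost:9380/v1/api/document/upload",  # assumed mount point
    headers={"Authorization": "Bearer ragflow-<token>"},
    files={"file": ("manual.pdf", open("manual.pdf", "rb"))},
    data={
        "kb_name": "my_kb",    # assumed form-field naming
        "parser_id": "naive",  # optional: override the auto-selected parser
        "run": "1",            # optional: enqueue parsing immediately
    },
)
print(resp.json())
```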
@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask import request, Response, jsonify
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
import json


@manager.route('/set', methods=['POST'])
@@ -103,9 +104,12 @@ def list_convsersation():
@manager.route('/completion', methods=['POST'])
@login_required
@validate_request("conversation_id", "messages")
def completion():
    req = request.json
    msg = []
    for m in req["messages"]:
        if m["role"] == "system":
@@ -123,13 +127,45 @@ def completion():
            return get_data_error_result(retmsg="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

        ans = chat(dia, msg, **req)
        if not conv.reference:
            conv.reference = []
        conv.reference.append(ans["reference"])
        conv.message.append({"role": "assistant", "content": ans["answer"]})
        ConversationService.update_by_id(conv.id, conv.to_dict())
        return get_json_result(data=ans)

        conv.message.append({"role": "assistant", "content": ""})
        conv.reference.append({"chunks": [], "doc_aggs": []})

        def fillin_conv(ans):
            nonlocal conv
            if not conv.reference:
                conv.reference.append(ans["reference"])
            else:
                conv.reference[-1] = ans["reference"]
            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}

        def stream():
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    fillin_conv(ans)
                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
                ConversationService.update_by_id(conv.id, conv.to_dict())
            except Exception as e:
                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
                                           ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"

        if req.get("stream", True):
            resp = Response(stream(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp
        else:
            answer = None
            for ans in chat(dia, msg, False, **req):
                answer = ans
            fillin_conv(answer)
            ConversationService.update_by_id(conv.id, conv.to_dict())
            return get_json_result(data=answer)
    except Exception as e:
        return server_error_response(e)
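
Setting `stream` to false preserves the pre-existing single-response behavior; a sketch (path and session handling assumed, response shape from `get_json_result` and `decorate_answer`):

```python
import requests

resp = requests.post(
    "http://localhost:9380/v1/conversation/completion",  # assumed mount point
    json={
        "conversation_id": "<conversation-id>",
        "messages": [{"role": "user", "content": "Summarize the document."}],
        "stream": False,  # one JSON body instead of an SSE stream
    },
    cookies={"session": "<session cookie>"},  # endpoint is @login_required
)
body = resp.json()
print(body["data"]["answer"])
print([d["doc_id"] for d in body["data"]["reference"]["doc_aggs"]])
```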
@@ -0,0 +1,67 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask_login import login_required
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils.api_utils import get_json_result
from api.versions import get_rag_version
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.minio_conn import MINIO
from timeit import default_timer as timer
from rag.utils.redis_conn import REDIS_CONN


@manager.route('/version', methods=['GET'])
@login_required
def version():
    return get_json_result(data=get_rag_version())
@manager.route('/status', methods=['GET'])
@login_required
def status():
    res = {}
    st = timer()
    try:
        res["es"] = ELASTICSEARCH.health()
        res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.)
    except Exception as e:
        res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}

    st = timer()
    try:
        MINIO.health()
        res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.)}
    except Exception as e:
        res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}

    st = timer()
    try:
        KnowledgebaseService.get_by_id("x")
        res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.)}
    except Exception as e:
        res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}

    st = timer()
    try:
        qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "pending": qinfo["pending"]}
    except Exception as e:
        res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}

    return get_json_result(data=res)
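
Each dependency reports a status color plus the probe's elapsed time in milliseconds, along with an error string or queue pending-count where applicable. A polling sketch with an assumed mount point:

```python
import requests

resp = requests.get(
    "http://localhost:9380/v1/system/status",  # assumed mount point
    cookies={"session": "<session cookie>"},   # endpoint is @login_required
)
for svc, info in resp.json()["data"].items():
    print(f'{svc}: {info.get("status", "?")} in {info["elapsed"]} ms')
```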
@@ -14,6 +14,7 @@
# limitations under the License.
#
import re
from copy import deepcopy

from api.db import LLMType
from api.db.db_models import Dialog, Conversation
@@ -71,7 +72,7 @@ def message_fit_in(msg, max_length=4000):
    return max_length, msg


def chat(dialog, messages, **kwargs):
def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    llm = LLMService.query(llm_name=dialog.llm_id)
    if not llm:
@@ -82,7 +83,10 @@ def chat(dialog, messages, **kwargs):
    else: max_tokens = llm[0].max_tokens
    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embd_nms = list(set([kb.embd_id for kb in kbs]))

    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
    if len(embd_nms) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return

    questions = [m["content"] for m in messages if m["role"] == "user"]
    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
@@ -94,7 +98,9 @@ def chat(dialog, messages, **kwargs):
    if field_map:
        chat_logger.info("Using SQL for retrieval: {}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
@@ -118,8 +124,9 @@ def chat(dialog, messages, **kwargs):
        "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    if not knowledges and prompt_config.get("empty_response"):
        return {
            "answer": prompt_config["empty_response"], "reference": kbinfos}
        yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
        return

    kwargs["knowledge"] = "\n".join(knowledges)
    gen_conf = dialog.llm_setting
@@ -130,33 +137,45 @@ def chat(dialog, messages, **kwargs):
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)

    answer = chat_mdl.chat(
        prompt_config["system"].format(
            **kwargs), msg, gen_conf)
    chat_logger.info("User: {}|Assistant: {}".format(
        msg[-1]["content"], answer))
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        answer, idx = retrievaler.insert_citations(answer,
                                                   [ck["content_ltks"]
                                                    for ck in kbinfos["chunks"]],
                                                   [ck["vector"]
                                                    for ck in kbinfos["chunks"]],
                                                   embd_mdl,
                                                   tkweight=1 - dialog.vector_similarity_weight,
                                                   vtweight=dialog.vector_similarity_weight)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
    for c in kbinfos["chunks"]:
        if c.get("vector"):
            del c["vector"]
    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
        answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
    return {"answer": answer, "reference": kbinfos}

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retrievaler.insert_citations(answer,
                                                       [ck["content_ltks"]
                                                        for ck in kbinfos["chunks"]],
                                                       [ck["vector"]
                                                        for ck in kbinfos["chunks"]],
                                                       embd_mdl,
                                                       tkweight=1 - dialog.vector_similarity_weight,
                                                       vtweight=dialog.vector_similarity_weight)
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs: recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]
        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": refs}

    if stream:
        answer = ""
        for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
            answer = ans
            yield {"answer": answer, "reference": {}}
        yield decorate_answer(answer)
    else:
        answer = chat_mdl.chat(
            prompt_config["system"].format(
                **kwargs), msg, gen_conf)
        chat_logger.info("User: {}|Assistant: {}".format(
            msg[-1]["content"], answer))
        yield decorate_answer(answer)
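
With this change `chat` is a generator in both modes: while streaming it yields the growing answer with an empty reference, and its final item is the citation-decorated result; with `stream=False` it yields exactly one decorated item. A minimal consumption sketch (the two callbacks are hypothetical):

```python
final = None
for ans in chat(dia, msg, True, quote=True):
    final = ans                      # {"answer": ..., "reference": ...}
    render_partial(ans["answer"])    # hypothetical UI callback

persist_reference(final["reference"])  # hypothetical persistence hook
```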


def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
@@ -43,7 +43,7 @@ class DocumentService(CommonService):
        docs = cls.model.select().where(
            (cls.model.kb_id == kb_id),
            (fn.LOWER(cls.model.name).contains(keywords.lower()))
        )
    else:
        docs = cls.model.select().where(cls.model.kb_id == kb_id)
    count = docs.count()
@@ -75,7 +75,7 @@ class DocumentService(CommonService):
    def delete(cls, doc):
        e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num - 1}):
                kb.id, {"doc_num": max(0, kb.doc_num - 1)}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return cls.delete_by_id(doc.id)
@@ -172,8 +172,18 @@ class LLMBundle(object):
    def chat(self, system, history, gen_conf):
        txt, used_tokens = self.mdl.chat(system, history, gen_conf)
        if TenantLLMService.increase_usage(
        if not TenantLLMService.increase_usage(
                self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            database_logger.error(
                "Can't update token usage for {}/CHAT".format(self.tenant_id))
        return txt

    def chat_streamly(self, system, history, gen_conf):
        for txt in self.mdl.chat_streamly(system, history, gen_conf):
            if isinstance(txt, int):
                if not TenantLLMService.increase_usage(
                        self.tenant_id, self.llm_type, txt, self.llm_name):
                    database_logger.error(
                        "Can't update token usage for {}/CHAT".format(self.tenant_id))
                return
            yield txt
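
The per-provider `chat_streamly` implementations below all follow the convention this wrapper relies on: yield the cumulative answer string as chunks arrive, then yield a single `int` (tokens consumed) as the very last item, which `LLMBundle` intercepts for usage accounting instead of forwarding. A toy generator illustrating the contract:

```python
def toy_chat_streamly(system, history, gen_conf):
    # Yields growing answer snapshots, then the token count as a final int.
    ans = ""
    for piece in ["Stream", "ing ", "works."]:
        ans += piece
        yield ans
    yield 42  # tokens used; swallowed by LLMBundle.chat_streamly

for item in toy_chat_streamly(None, [], {}):
    print(item)
```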
@@ -25,7 +25,6 @@ from flask import (
from werkzeug.http import HTTP_STATUS_CODES

from api.utils import json_dumps
from api.versions import get_rag_version
from api.settings import RetCode
from api.settings import (
    REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
@@ -84,9 +83,6 @@ def request(**kwargs):
    return sess.send(prepped, stream=stream, timeout=timeout)


rag_version = get_rag_version() or ''


def get_exponential_backoff_interval(retries, full_jitter=False):
    """Calculate the exponential backoff wait time."""
    # Will be zero if factor equals 0
@@ -20,7 +20,6 @@ from openai import OpenAI
import openai
from ollama import Client
from rag.nlp import is_english
from rag.utils import num_tokens_from_string


class Base(ABC):
@@ -44,6 +43,31 @@ class Base(ABC):
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content: continue
                ans += resp.choices[0].delta.content
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nThe answer was cut off for length. Continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -97,6 +121,35 @@ class QWenChat(Base):
        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nThe answer was cut off for length. Continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 \
                        else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content: continue
                delta = resp.choices[0].delta.content
                ans += delta
                if resp.usage:
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nThe answer was cut off for length. Continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
@@ -148,3 +229,28 @@ class OllamaChat(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options
            )
            for resp in response:
                if resp["done"]:
                    yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                    return
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield 0
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
    if to_page > 0:
        if msg:
            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
    d = {"progress_msg": msg}
    if prog is not None:
        d["progress"] = prog
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):


def build(row):
    if row["size"] > DOC_MAXIMUM_SIZE:
        set_progress(row["id"], prog=-1, msg="File size exceeds the limit (<= %d MB)" %
                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
        return []

    callback = partial(
@@ -138,12 +138,12 @@ def build(row):
        bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
        binary = get_minio_binary(bucket, name)
        cron_logger.info(
            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
        cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                            to_page=row["to_page"], lang=row["language"], callback=callback,
                            kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
        cron_logger.info(
            "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
    except TimeoutError as e:
        callback(-1, "Internal server error: fetching the file timed out. Please try again.")
        cron_logger.error(
@@ -173,7 +173,7 @@ def build(row):
        d.update(ck)
        md5 = hashlib.md5()
        md5.update((ck["content_with_weight"] +
                    str(d["doc_id"])).encode("utf-8"))
        d["_id"] = md5.hexdigest()
        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -261,7 +261,7 @@ def main():
        st = timer()
        cks = build(r)
        cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer() - st))
        if cks is None:
            continue
        if not cks:
@@ -271,7 +271,7 @@ def main():
        # set_progress(r["did"], -1, "ERROR: ")
        callback(
            msg="Finished slicing files (%d). Starting to embed the content." %
                len(cks))
        st = timer()
        try:
            tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
@@ -279,19 +279,19 @@ def main():
            callback(-1, "Embedding error:{}".format(str(e)))
            cron_logger.error(str(e))
            tk_count = 0
        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
        callback(msg="Finished embedding ({:.2f})! Starting to build the index!".format(timer() - st))
        init_kb(r)
        chunk_count = len(set([c["_id"] for c in cks]))
        st = timer()
        es_r = ""
        for b in range(0, len(cks), 32):
            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
            if b % 128 == 0:
                callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
        if es_r:
            callback(-1, "Index failure!")
            ELASTICSEARCH.deleteByQuery(
@@ -307,8 +307,7 @@ def main():
            r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
        cron_logger.info(
            "Chunk doc({}), token({}), chunks({}), elapsed: {:.2f}".format(
                r["id"], tk_count, len(cks), timer() - st))


if __name__ == "__main__":
@@ -43,6 +43,9 @@ class ESConnection:
        v = v["number"].split(".")[0]
        return int(v) >= 7

    def health(self):
        return dict(self.es.cluster.health())

    def upsert(self, df, idxnm=""):
        res = []
        for d in df:
@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
            del self.conn
        self.conn = None

    def health(self):
        bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
        if not self.conn.bucket_exists(bucket):
            self.conn.make_bucket(bucket)
        r = self.conn.put_object(bucket, fnm,
                                 BytesIO(binary),
                                 len(binary))
        return r

    def put(self, bucket, fnm, binary):
        for _ in range(3):
            try:
@@ -44,6 +44,10 @@ class RedisDB:
            logging.warning("Redis can't be connected.")
        return self.REDIS

    def health(self, queue_name):
        self.REDIS.ping()
        return self.REDIS.xinfo_groups(queue_name)[0]

    def is_alive(self):
        return self.REDIS is not None