### What problem does this PR solve? ### Type of change - [x] Refactoringtags/v0.10.0
| @@ -26,7 +26,7 @@ from api.db.db_models import APIToken, API4Conversation, Task, File | |||
| from api.db.services import duplicate_name | |||
| from api.db.services.api_service import APITokenService, API4ConversationService | |||
| from api.db.services.dialog_service import DialogService, chat | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.document_service import DocumentService, doc_upload_and_parse | |||
| from api.db.services.file2document_service import File2DocumentService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| @@ -470,6 +470,29 @@ def upload(): | |||
| return get_json_result(data=doc_result.to_json()) | |||
| @manager.route('/document/upload_and_parse', methods=['POST']) | |||
| @validate_request("conversation_id") | |||
| def upload_parse(): | |||
| token = request.headers.get('Authorization').split()[1] | |||
| objs = APIToken.query(token=token) | |||
| if not objs: | |||
| return get_json_result( | |||
| data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) | |||
| if 'file' not in request.files: | |||
| return get_json_result( | |||
| data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) | |||
| file_objs = request.files.getlist('file') | |||
| for file_obj in file_objs: | |||
| if file_obj.filename == '': | |||
| return get_json_result( | |||
| data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) | |||
| doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id) | |||
| return get_json_result(data=doc_ids) | |||
| @manager.route('/list_chunks', methods=['POST']) | |||
| # @login_required | |||
| def list_chunks(): | |||
| @@ -560,7 +583,6 @@ def document_rm(): | |||
| tenant_id = objs[0].tenant_id | |||
| req = request.json | |||
| doc_ids = [] | |||
| try: | |||
| doc_ids = [DocumentService.get_doc_id_by_doc_name(doc_name) for doc_name in req.get("doc_names", [])] | |||
| for doc_id in req.get("doc_ids", []): | |||
| @@ -45,7 +45,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService | |||
| from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | |||
| from api.utils import get_uuid | |||
| from api.db import FileType, TaskStatus, ParserType, FileSource, LLMType | |||
| from api.db.services.document_service import DocumentService | |||
| from api.db.services.document_service import DocumentService, doc_upload_and_parse | |||
| from api.settings import RetCode, stat_logger | |||
| from api.utils.api_utils import get_json_result | |||
| from rag.utils.minio_conn import MINIO | |||
| @@ -75,7 +75,7 @@ def upload(): | |||
| if not e: | |||
| raise LookupError("Can't find this knowledgebase!") | |||
| err, _ = FileService.upload_document(kb, file_objs) | |||
| err, _ = FileService.upload_document(kb, file_objs, current_user.id) | |||
| if err: | |||
| return get_json_result( | |||
| data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) | |||
| @@ -212,7 +212,7 @@ def docinfos(): | |||
| @manager.route('/thumbnails', methods=['GET']) | |||
| @login_required | |||
| #@login_required | |||
| def thumbnails(): | |||
| doc_ids = request.args.get("doc_ids").split(",") | |||
| if not doc_ids: | |||
| @@ -460,7 +460,6 @@ def get_image(image_id): | |||
| @login_required | |||
| @validate_request("conversation_id") | |||
| def upload_and_parse(): | |||
| from rag.app import presentation, picture, naive, audio, email | |||
| if 'file' not in request.files: | |||
| return get_json_result( | |||
| data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR) | |||
| @@ -471,124 +470,6 @@ def upload_and_parse(): | |||
| return get_json_result( | |||
| data=False, retmsg='No file selected!', retcode=RetCode.ARGUMENT_ERROR) | |||
| e, conv = ConversationService.get_by_id(request.form.get("conversation_id")) | |||
| if not e: | |||
| return get_data_error_result(retmsg="Conversation not found!") | |||
| e, dia = DialogService.get_by_id(conv.dialog_id) | |||
| kb_id = dia.kb_ids[0] | |||
| e, kb = KnowledgebaseService.get_by_id(kb_id) | |||
| if not e: | |||
| raise LookupError("Can't find this knowledgebase!") | |||
| idxnm = search.index_name(kb.tenant_id) | |||
| if not ELASTICSEARCH.indexExist(idxnm): | |||
| ELASTICSEARCH.createIdx(idxnm, json.load( | |||
| open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) | |||
| embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) | |||
| err, files = FileService.upload_document(kb, file_objs) | |||
| if err: | |||
| return get_json_result( | |||
| data=False, retmsg="\n".join(err), retcode=RetCode.SERVER_ERROR) | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| FACTORY = { | |||
| ParserType.PRESENTATION.value: presentation, | |||
| ParserType.PICTURE.value: picture, | |||
| ParserType.AUDIO.value: audio, | |||
| ParserType.EMAIL.value: email | |||
| } | |||
| parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False} | |||
| exe = ThreadPoolExecutor(max_workers=12) | |||
| threads = [] | |||
| for d, blob in files: | |||
| kwargs = { | |||
| "callback": dummy, | |||
| "parser_config": parser_config, | |||
| "from_page": 0, | |||
| "to_page": 100000, | |||
| "tenant_id": kb.tenant_id, | |||
| "lang": kb.language | |||
| } | |||
| threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) | |||
| doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, current_user.id) | |||
| for (docinfo,_), th in zip(files, threads): | |||
| docs = [] | |||
| doc = { | |||
| "doc_id": docinfo["id"], | |||
| "kb_id": [kb.id] | |||
| } | |||
| for ck in th.result(): | |||
| d = deepcopy(doc) | |||
| d.update(ck) | |||
| md5 = hashlib.md5() | |||
| md5.update((ck["content_with_weight"] + | |||
| str(d["doc_id"])).encode("utf-8")) | |||
| d["_id"] = md5.hexdigest() | |||
| d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] | |||
| d["create_timestamp_flt"] = datetime.datetime.now().timestamp() | |||
| if not d.get("image"): | |||
| docs.append(d) | |||
| continue | |||
| output_buffer = BytesIO() | |||
| if isinstance(d["image"], bytes): | |||
| output_buffer = BytesIO(d["image"]) | |||
| else: | |||
| d["image"].save(output_buffer, format='JPEG') | |||
| MINIO.put(kb.id, d["_id"], output_buffer.getvalue()) | |||
| d["img_id"] = "{}-{}".format(kb.id, d["_id"]) | |||
| del d["image"] | |||
| docs.append(d) | |||
| parser_ids = {d["id"]: d["parser_id"] for d, _ in files} | |||
| docids = [d["id"] for d, _ in files] | |||
| chunk_counts = {id: 0 for id in docids} | |||
| token_counts = {id: 0 for id in docids} | |||
| es_bulk_size = 64 | |||
| def embedding(doc_id, cnts, batch_size=16): | |||
| nonlocal embd_mdl, chunk_counts, token_counts | |||
| vects = [] | |||
| for i in range(0, len(cnts), batch_size): | |||
| vts, c = embd_mdl.encode(cnts[i: i + batch_size]) | |||
| vects.extend(vts.tolist()) | |||
| chunk_counts[doc_id] += len(cnts[i:i + batch_size]) | |||
| token_counts[doc_id] += c | |||
| return vects | |||
| _, tenant = TenantService.get_by_id(kb.tenant_id) | |||
| llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id) | |||
| for doc_id in docids: | |||
| cks = [c for c in docs if c["doc_id"] == doc_id] | |||
| if False and parser_ids[doc_id] != ParserType.PICTURE.value: | |||
| mindmap = MindMapExtractor(llm_bdl) | |||
| try: | |||
| mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, ensure_ascii=False, indent=2) | |||
| if len(mind_map) < 32: raise Exception("Few content: "+mind_map) | |||
| cks.append({ | |||
| "doc_id": doc_id, | |||
| "kb_id": [kb.id], | |||
| "content_with_weight": mind_map, | |||
| "knowledge_graph_kwd": "mind_map" | |||
| }) | |||
| except Exception as e: | |||
| stat_logger.error("Mind map generation error:", traceback.format_exc()) | |||
| vects = embedding(doc_id, [c["content_with_weight"] for c in cks]) | |||
| assert len(cks) == len(vects) | |||
| for i, d in enumerate(cks): | |||
| v = vects[i] | |||
| d["q_%d_vec" % len(v)] = v | |||
| for b in range(0, len(cks), es_bulk_size): | |||
| ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm) | |||
| DocumentService.increment_chunk_num( | |||
| doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) | |||
| return get_json_result(data=[d["id"] for d,_ in files]) | |||
| return get_json_result(data=doc_ids) | |||
| @@ -13,20 +13,29 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # | |||
| import hashlib | |||
| import json | |||
| import os | |||
| import random | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from copy import deepcopy | |||
| from datetime import datetime | |||
| from io import BytesIO | |||
| from elasticsearch_dsl import Q | |||
| from peewee import fn | |||
| from api.db.db_utils import bulk_insert_into_db | |||
| from api.settings import stat_logger | |||
| from api.utils import current_timestamp, get_format_time, get_uuid | |||
| from api.utils.file_utils import get_project_base_directory | |||
| from graphrag.mind_map_extractor import MindMapExtractor | |||
| from rag.settings import SVR_QUEUE_NAME | |||
| from rag.utils.es_conn import ELASTICSEARCH | |||
| from rag.utils.minio_conn import MINIO | |||
| from rag.nlp import search | |||
| from api.db import FileType, TaskStatus, ParserType | |||
| from api.db import FileType, TaskStatus, ParserType, LLMType | |||
| from api.db.db_models import DB, Knowledgebase, Tenant, Task | |||
| from api.db.db_models import Document | |||
| from api.db.services.common_service import CommonService | |||
| @@ -380,3 +389,136 @@ def queue_raptor_tasks(doc): | |||
| bulk_insert_into_db(Task, [task], True) | |||
| task["type"] = "raptor" | |||
| assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status." | |||
| def doc_upload_and_parse(conversation_id, file_objs, user_id): | |||
| from rag.app import presentation, picture, naive, audio, email | |||
| from api.db.services.dialog_service import ConversationService, DialogService | |||
| from api.db.services.file_service import FileService | |||
| from api.db.services.llm_service import LLMBundle | |||
| from api.db.services.user_service import TenantService | |||
| from api.db.services.api_service import API4ConversationService | |||
| e, conv = ConversationService.get_by_id(conversation_id) | |||
| if not e: | |||
| e, conv = API4ConversationService.get_by_id(conversation_id) | |||
| assert e, "Conversation not found!" | |||
| e, dia = DialogService.get_by_id(conv.dialog_id) | |||
| kb_id = dia.kb_ids[0] | |||
| e, kb = KnowledgebaseService.get_by_id(kb_id) | |||
| if not e: | |||
| raise LookupError("Can't find this knowledgebase!") | |||
| idxnm = search.index_name(kb.tenant_id) | |||
| if not ELASTICSEARCH.indexExist(idxnm): | |||
| ELASTICSEARCH.createIdx(idxnm, json.load( | |||
| open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r"))) | |||
| embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language) | |||
| err, files = FileService.upload_document(kb, file_objs, user_id) | |||
| assert not err, "\n".join(err) | |||
| def dummy(prog=None, msg=""): | |||
| pass | |||
| FACTORY = { | |||
| ParserType.PRESENTATION.value: presentation, | |||
| ParserType.PICTURE.value: picture, | |||
| ParserType.AUDIO.value: audio, | |||
| ParserType.EMAIL.value: email | |||
| } | |||
| parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": False} | |||
| exe = ThreadPoolExecutor(max_workers=12) | |||
| threads = [] | |||
| for d, blob in files: | |||
| kwargs = { | |||
| "callback": dummy, | |||
| "parser_config": parser_config, | |||
| "from_page": 0, | |||
| "to_page": 100000, | |||
| "tenant_id": kb.tenant_id, | |||
| "lang": kb.language | |||
| } | |||
| threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs)) | |||
| for (docinfo, _), th in zip(files, threads): | |||
| docs = [] | |||
| doc = { | |||
| "doc_id": docinfo["id"], | |||
| "kb_id": [kb.id] | |||
| } | |||
| for ck in th.result(): | |||
| d = deepcopy(doc) | |||
| d.update(ck) | |||
| md5 = hashlib.md5() | |||
| md5.update((ck["content_with_weight"] + | |||
| str(d["doc_id"])).encode("utf-8")) | |||
| d["_id"] = md5.hexdigest() | |||
| d["create_time"] = str(datetime.now()).replace("T", " ")[:19] | |||
| d["create_timestamp_flt"] = datetime.now().timestamp() | |||
| if not d.get("image"): | |||
| docs.append(d) | |||
| continue | |||
| output_buffer = BytesIO() | |||
| if isinstance(d["image"], bytes): | |||
| output_buffer = BytesIO(d["image"]) | |||
| else: | |||
| d["image"].save(output_buffer, format='JPEG') | |||
| MINIO.put(kb.id, d["_id"], output_buffer.getvalue()) | |||
| d["img_id"] = "{}-{}".format(kb.id, d["_id"]) | |||
| del d["image"] | |||
| docs.append(d) | |||
| parser_ids = {d["id"]: d["parser_id"] for d, _ in files} | |||
| docids = [d["id"] for d, _ in files] | |||
| chunk_counts = {id: 0 for id in docids} | |||
| token_counts = {id: 0 for id in docids} | |||
| es_bulk_size = 64 | |||
| def embedding(doc_id, cnts, batch_size=16): | |||
| nonlocal embd_mdl, chunk_counts, token_counts | |||
| vects = [] | |||
| for i in range(0, len(cnts), batch_size): | |||
| vts, c = embd_mdl.encode(cnts[i: i + batch_size]) | |||
| vects.extend(vts.tolist()) | |||
| chunk_counts[doc_id] += len(cnts[i:i + batch_size]) | |||
| token_counts[doc_id] += c | |||
| return vects | |||
| _, tenant = TenantService.get_by_id(kb.tenant_id) | |||
| llm_bdl = LLMBundle(kb.tenant_id, LLMType.CHAT, tenant.llm_id) | |||
| for doc_id in docids: | |||
| cks = [c for c in docs if c["doc_id"] == doc_id] | |||
| if parser_ids[doc_id] != ParserType.PICTURE.value: | |||
| mindmap = MindMapExtractor(llm_bdl) | |||
| try: | |||
| mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output, | |||
| ensure_ascii=False, indent=2) | |||
| if len(mind_map) < 32: raise Exception("Few content: " + mind_map) | |||
| cks.append({ | |||
| "id": get_uuid(), | |||
| "doc_id": doc_id, | |||
| "kb_id": [kb.id], | |||
| "content_with_weight": mind_map, | |||
| "knowledge_graph_kwd": "mind_map" | |||
| }) | |||
| except Exception as e: | |||
| stat_logger.error("Mind map generation error:", traceback.format_exc()) | |||
| vects = embedding(doc_id, [c["content_with_weight"] for c in cks]) | |||
| assert len(cks) == len(vects) | |||
| for i, d in enumerate(cks): | |||
| v = vects[i] | |||
| d["q_%d_vec" % len(v)] = v | |||
| for b in range(0, len(cks), es_bulk_size): | |||
| ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], idxnm) | |||
| DocumentService.increment_chunk_num( | |||
| doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0) | |||
| return [d["id"] for d,_ in files] | |||
| @@ -327,11 +327,11 @@ class FileService(CommonService): | |||
| @classmethod | |||
| @DB.connection_context() | |||
| def upload_document(self, kb, file_objs): | |||
| root_folder = self.get_root_folder(current_user.id) | |||
| def upload_document(self, kb, file_objs, user_id): | |||
| root_folder = self.get_root_folder(user_id) | |||
| pf_id = root_folder["id"] | |||
| self.init_knowledgebase_docs(pf_id, current_user.id) | |||
| kb_root_folder = self.get_kb_folder(current_user.id) | |||
| self.init_knowledgebase_docs(pf_id, user_id) | |||
| kb_root_folder = self.get_kb_folder(user_id) | |||
| kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"]) | |||
| err, files = [], [] | |||
| @@ -359,7 +359,7 @@ class FileService(CommonService): | |||
| "kb_id": kb.id, | |||
| "parser_id": kb.parser_id, | |||
| "parser_config": kb.parser_config, | |||
| "created_by": current_user.id, | |||
| "created_by": user_id, | |||
| "type": filetype, | |||
| "name": filename, | |||
| "location": location, | |||