### What problem does this PR solve?

#882

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -60,7 +60,8 @@ def status():
     st = timer()
     try:
         qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
-        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
+        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.),
+                        "pending": qinfo.get("pending", 0)}
     except Exception as e:
         res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
```
```diff
@@ -18,8 +18,10 @@ from datetime import datetime
 from elasticsearch_dsl import Q
 from peewee import fn
+from api.db.db_utils import bulk_insert_into_db
 from api.settings import stat_logger
-from api.utils import current_timestamp, get_format_time
+from api.utils import current_timestamp, get_format_time, get_uuid
+from rag.settings import SVR_QUEUE_NAME
 from rag.utils.es_conn import ELASTICSEARCH
 from rag.utils.minio_conn import MINIO
 from rag.nlp import search
@@ -30,6 +32,7 @@ from api.db.db_models import Document
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db import StatusEnum
+from rag.utils.redis_conn import REDIS_CONN


 class DocumentService(CommonService):
@@ -110,7 +113,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_unfinished_docs(cls):
-        fields = [cls.model.id, cls.model.process_begin_at]
+        fields = [cls.model.id, cls.model.process_begin_at, cls.model.parser_config, cls.model.progress_msg]
         docs = cls.model.select(*fields) \
             .where(
                 cls.model.status == StatusEnum.VALID.value,
```
| @@ -260,7 +263,12 @@ class DocumentService(CommonService): | |||
| prg = -1 | |||
| status = TaskStatus.FAIL.value | |||
| elif finished: | |||
| status = TaskStatus.DONE.value | |||
| if d["parser_config"].get("raptor") and d["progress_msg"].lower().find(" raptor")<0: | |||
| queue_raptor_tasks(d) | |||
| prg *= 0.98 | |||
| msg.append("------ RAPTOR -------") | |||
| else: | |||
| status = TaskStatus.DONE.value | |||
| msg = "\n".join(msg) | |||
| info = { | |||
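A note on the guard above: the queued RAPTOR task writes "RAPTOR" into `progress_msg` (see `queue_raptor_tasks` below), so its presence marks documents whose extra pass was already scheduled, while `prg *= 0.98` keeps the reported progress just under completion until that pass finishes. A minimal sketch that only restates the condition for illustration — not the PR's code:

```python
# Illustration only: restates the re-queue guard from the hunk above.
# The queued RAPTOR task's progress_msg contains "RAPTOR", so finding
# " raptor" in the lowercased log means the pass was already scheduled.
def should_queue_raptor(parser_config: dict, progress_msg: str) -> bool:
    return bool(parser_config.get("raptor")) and " raptor" not in progress_msg.lower()

assert should_queue_raptor({"raptor": {"prompt": "..."}}, "Task has been received.")
assert not should_queue_raptor({"raptor": {"prompt": "..."}},
                               "Start to do RAPTOR (Recursive Abstractive ...).")
```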
```diff
@@ -282,3 +290,19 @@ class DocumentService(CommonService):
         return len(cls.model.select(cls.model.id).where(
             cls.model.kb_id == kb_id).dicts())
+
+
+def queue_raptor_tasks(doc):
+    def new_task():
+        nonlocal doc
+        return {
+            "id": get_uuid(),
+            "doc_id": doc["id"],
+            "from_page": 0,
+            "to_page": -1,
+            "progress_msg": "Start to do RAPTOR (Recursive Abstractive Processing For Tree-Organized Retrieval)."
+        }
+
+    task = new_task()
+    bulk_insert_into_db(Task, [task], True)
+    task["type"] = "raptor"
+    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
```
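Worth noting: `task["type"] = "raptor"` is set only after `bulk_insert_into_db`, so the marker travels in the Redis message rather than in the persisted Task row. A hypothetical consumer-side sketch of that contract (the real counterpart is the `collect()` change further down in this PR):

```python
# Hypothetical sketch of the message contract: "type" exists only in the
# queued payload, not in the Task row, so the executor decides from the
# message -- mirroring the collect() change later in this PR.
def classify(msg: dict) -> str:
    return "raptor" if msg.get("type", "") == "raptor" else "chunking"

assert classify({"id": "abc", "doc_id": "d1", "type": "raptor"}) == "raptor"
assert classify({"id": "abc", "doc_id": "d1"}) == "chunking"
```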
```diff
@@ -155,6 +155,10 @@ class LLMBundle(object):
             tenant_id, llm_type, llm_name, lang=lang)
         assert self.mdl, "Can't find mole for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
+        self.max_length = 512
+        for lm in LLMService.query(llm_name=llm_name):
+            self.max_length = lm.max_tokens
+            break

     def encode(self, texts: list, batch_size=32):
         emd, used_tokens = self.mdl.encode(texts, batch_size)
```
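The looked-up `max_length` feeds the token budget in the new `rag/raptor.py` below: a cluster's texts are truncated so they fit into the model window minus the tokens reserved for the generated summary. A minimal sketch of that arithmetic — the 512 fallback matches the hunk above, the 8192 figure is just an example:

```python
# Sketch of the budget used by raptor's summarize(): the texts in a
# cluster share whatever remains of the model window after reserving
# max_token for the generated summary.
def per_chunk_budget(max_length: int, max_token: int, n_texts: int) -> int:
    return max(1, (max_length - max_token) // n_texts)

print(per_chunk_budget(max_length=8192, max_token=256, n_texts=10))   # 793
print(per_chunk_budget(max_length=512, max_token=256, n_texts=300))   # 1 (floor)
```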
```diff
@@ -53,6 +53,7 @@ class TaskService(CommonService):
                   Knowledgebase.embd_id,
                   Tenant.img2txt_id,
                   Tenant.asr_id,
+                  Tenant.llm_id,
                   cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Document, on=(cls.model.doc_id == Document.id)) \
@@ -159,4 +160,4 @@ def queue_tasks(doc, bucket, name):
     DocumentService.begin2parse(doc["id"])

     for t in tsks:
-        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
+        assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=t), "Can't access Redis. Please check the Redis' status."
```
```diff
@@ -57,8 +57,7 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if len(resp.choices) == 0: continue
-                if not resp.choices[0].delta.content: continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
```
```diff
@@ -379,7 +378,7 @@ class VolcEngineChat(Base):
                     ans += resp.choices[0].message.content
                     yield ans
                     if resp.choices[0].finish_reason == "stop":
-                        return resp.usage.total_tokens
+                        yield resp.usage.total_tokens
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
```
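The one-line change above fixes a real generator pitfall: inside a generator, `return x` only sets `StopIteration.value`, which a `for` loop silently discards, so the token count never reached the caller. A minimal repro:

```python
# `return x` in a generator does not deliver x to a for-loop consumer;
# it becomes StopIteration.value and is dropped. `yield x` delivers it
# like any other item, which is what the streaming caller expects here.
def broken():
    yield "partial answer"
    return 42            # lost: only visible via StopIteration.value

def fixed():
    yield "partial answer"
    yield 42             # reaches the consumer

print(list(broken()))    # ['partial answer']
print(list(fixed()))     # ['partial answer', 42]
```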
```diff
@@ -0,0 +1,114 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+import re
+import traceback
+from concurrent.futures import ThreadPoolExecutor, ALL_COMPLETED, wait
+from threading import Lock
+from typing import Tuple
+import umap
+import numpy as np
+from sklearn.mixture import GaussianMixture
+
+from rag.utils import num_tokens_from_string, truncate
+
+
+class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
+    def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=256, threshold=0.1):
+        self._max_cluster = max_cluster
+        self._llm_model = llm_model
+        self._embd_model = embd_model
+        self._threshold = threshold
+        self._prompt = prompt
+        self._max_token = max_token
+
+    def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int):
+        max_clusters = min(self._max_cluster, len(embeddings))
+        n_clusters = np.arange(1, max_clusters)
+        bics = []
+        for n in n_clusters:
+            gm = GaussianMixture(n_components=n, random_state=random_state)
+            gm.fit(embeddings)
+            bics.append(gm.bic(embeddings))
+        optimal_clusters = n_clusters[np.argmin(bics)]
+        return optimal_clusters
```
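`_get_optimal_clusters` picks the component count with the lowest Bayesian Information Criterion; note that `np.arange(1, max_clusters)` never tries `max_clusters` itself. A quick self-contained check of the idea on toy data:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Two well-separated blobs: BIC should bottom out at n_components == 2.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 0.3, (50, 2)), rng.normal(5, 0.3, (50, 2))])
bics = [GaussianMixture(n_components=n, random_state=0).fit(X).bic(X)
        for n in range(1, 6)]
print(int(np.argmin(bics)) + 1)  # expected: 2
```

The new file continues with `__call__`, the layer-by-layer clustering loop: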
```diff
+    def __call__(self, chunks: Tuple[str, np.ndarray], random_state, callback=None):
+        layers = [(0, len(chunks))]
+        start, end = 0, len(chunks)
+        if len(chunks) <= 1: return
+
+        def summarize(ck_idx, lock):
+            nonlocal chunks
+            try:
+                texts = [chunks[i][0] for i in ck_idx]
+                len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
+                cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
+                cnt = self._llm_model.chat("You're a helpful assistant.",
+                                           [{"role": "user", "content": self._prompt.format(cluster_content=cluster_content)}],
+                                           {"temperature": 0.3, "max_tokens": self._max_token})
+                cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "", cnt)
+                print("SUM:", cnt)
+                embds, _ = self._embd_model.encode([cnt])
+                with lock:
+                    chunks.append((cnt, embds[0]))
+            except Exception as e:
+                print(e, flush=True)
+                traceback.print_stack(e)
+                return e
+
+        labels = []
+        while end - start > 1:
+            embeddings = [embd for _, embd in chunks[start: end]]
+            if len(embeddings) == 2:
+                summarize([start, start + 1], Lock())
+                if callback:
+                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+                labels.extend([0, 0])
+                layers.append((end, len(chunks)))
+                start = end
+                end = len(chunks)
+                continue
+
+            n_neighbors = int((len(embeddings) - 1) ** 0.8)
+            reduced_embeddings = umap.UMAP(
+                n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
+            ).fit_transform(embeddings)
+            n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
+            if n_clusters == 1:
+                lbls = [0 for _ in range(len(reduced_embeddings))]
+            else:
+                gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
+                gm.fit(reduced_embeddings)
+                probs = gm.predict_proba(reduced_embeddings)
+                lbls = [np.where(prob > self._threshold)[0] for prob in probs]
+            lock = Lock()
+            with ThreadPoolExecutor(max_workers=12) as executor:
+                threads = []
+                for c in range(n_clusters):
+                    ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
+                    threads.append(executor.submit(summarize, ck_idx, lock))
+                wait(threads, return_when=ALL_COMPLETED)
+                print([t.result() for t in threads])
+
+            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
+            labels.extend(lbls)
+            layers.append((end, len(chunks)))
+            if callback:
+                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
+            start = end
+            end = len(chunks)
```
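The thresholded `predict_proba` above is a soft assignment: a chunk joins every cluster whose posterior probability exceeds `threshold` (0.1 by default) rather than only its argmax cluster, so overlapping clusters can each summarize it. A small sketch of the difference:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Soft vs. hard assignment: with overlapping blobs, thresholding the
# posterior lets border points belong to more than one cluster.
rng = np.random.default_rng(42)
X = np.vstack([rng.normal(0, 1.0, (40, 2)), rng.normal(2.5, 1.0, (40, 2))])
gm = GaussianMixture(n_components=2, random_state=0).fit(X)
probs = gm.predict_proba(X)                  # shape (80, 2)
soft = [np.where(p > 0.1)[0] for p in probs]
multi = sum(len(s) > 1 for s in soft)
print(f"{multi} of {len(soft)} points sit in more than one cluster")
```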
```diff
@@ -26,20 +26,22 @@ import traceback
 from functools import partial
 from api.db.services.file2document_service import File2DocumentService
 from api.settings import retrievaler
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
 from rag.utils.minio_conn import MINIO
 from api.db.db_models import close_connection
 from rag.settings import database_logger, SVR_QUEUE_NAME
 from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
 from multiprocessing import Pool
 import numpy as np
-from elasticsearch_dsl import Q
+from elasticsearch_dsl import Q, Search
 from multiprocessing.context import TimeoutError
 from api.db.services.task_service import TaskService
 from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm
+from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
-from rag.nlp import search
+from rag.nlp import search, rag_tokenizer
 from io import BytesIO
 import pandas as pd
@@ -114,6 +116,8 @@ def collect():
     tasks = TaskService.get_tasks(msg["id"])
     assert tasks, "{} empty task!".format(msg["id"])
     tasks = pd.DataFrame(tasks)
+    if msg.get("type", "") == "raptor":
+        tasks["task_type"] = "raptor"
     return tasks
```
```diff
@@ -245,6 +249,47 @@ def embedding(docs, mdl, parser_config={}, callback=None):
     return tk_count


+def run_raptor(row, chat_mdl, embd_mdl, callback=None):
+    vts, _ = embd_mdl.encode(["ok"])
+    vctr_nm = "q_%d_vec" % len(vts[0])
+    chunks = []
+    for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
+        chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
+
+    raptor = Raptor(
+        row["parser_config"]["raptor"].get("max_cluster", 64),
+        chat_mdl,
+        embd_mdl,
+        row["parser_config"]["raptor"]["prompt"],
+        row["parser_config"]["raptor"]["max_token"],
+        row["parser_config"]["raptor"]["threshold"]
+    )
+    original_length = len(chunks)
+    raptor(chunks, row["parser_config"]["raptor"]["random_seed"], callback)
+    doc = {
+        "doc_id": row["doc_id"],
+        "kb_id": [str(row["kb_id"])],
+        "docnm_kwd": row["name"],
+        "title_tks": rag_tokenizer.tokenize(row["name"])
+    }
+    res = []
+    tk_count = 0
+    for content, vctr in chunks[original_length:]:
+        d = copy.deepcopy(doc)
+        md5 = hashlib.md5()
+        md5.update((content + str(d["doc_id"])).encode("utf-8"))
+        d["_id"] = md5.hexdigest()
+        d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
+        d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
+        d[vctr_nm] = vctr.tolist()
+        d["content_with_weight"] = content
+        d["content_ltks"] = rag_tokenizer.tokenize(content)
+        d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+        res.append(d)
+        tk_count += num_tokens_from_string(content)
+    return res, tk_count
+
+
 def main():
     rows = collect()
     if len(rows) == 0:
```
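One detail of `run_raptor` worth calling out: the ES `_id` is derived from the chunk content plus `doc_id`, so re-running RAPTOR over unchanged content yields the same ids and overwrites rather than duplicates the generated chunks. A minimal sketch:

```python
import hashlib

# Deterministic chunk ids, as in run_raptor above: identical summary
# content for the same document always maps to the same ES _id.
def chunk_id(content: str, doc_id: str) -> str:
    md5 = hashlib.md5()
    md5.update((content + doc_id).encode("utf-8"))
    return md5.hexdigest()

assert chunk_id("summary text", "doc-1") == chunk_id("summary text", "doc-1")
assert chunk_id("summary text", "doc-1") != chunk_id("summary text", "doc-2")
```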
```diff
@@ -259,35 +304,45 @@ def main():
             cron_logger.error(str(e))
             continue

-        st = timer()
-        cks = build(r)
-        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
-        if cks is None:
-            continue
-        if not cks:
-            callback(1., "No chunk! Done!")
-            continue
-        # TODO: exception handler
-        ## set_progress(r["did"], -1, "ERROR: ")
-        callback(
-            msg="Finished slicing files(%d). Start to embedding the content." %
-                len(cks))
-        st = timer()
-        try:
-            tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
-        except Exception as e:
-            callback(-1, "Embedding error:{}".format(str(e)))
-            cron_logger.error(str(e))
-            tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
+        if r.get("task_type", "") == "raptor":
+            try:
+                chat_mdl = LLMBundle(r["tenant_id"], LLMType.CHAT, llm_name=r["llm_id"], lang=r["language"])
+                cks, tk_count = run_raptor(r, chat_mdl, embd_mdl, callback)
+            except Exception as e:
+                callback(-1, msg=str(e))
+                cron_logger.error(str(e))
+                continue
+        else:
+            st = timer()
+            cks = build(r)
+            cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
+            if cks is None:
+                continue
+            if not cks:
+                callback(1., "No chunk! Done!")
+                continue
+            # TODO: exception handler
+            ## set_progress(r["did"], -1, "ERROR: ")
+            callback(
+                msg="Finished slicing files(%d). Start to embedding the content." %
+                    len(cks))
+            st = timer()
+            try:
+                tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
+            except Exception as e:
+                callback(-1, "Embedding error:{}".format(str(e)))
+                cron_logger.error(str(e))
+                tk_count = 0
+            cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
+            callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))

         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
-        for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
+        es_bulk_size = 16
+        for b in range(0, len(cks), es_bulk_size):
+            es_r = ELASTICSEARCH.bulk(cks[b:b + es_bulk_size], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
```
```diff
@@ -97,15 +97,17 @@ class RedisDB:
             return False

     def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
-        try:
-            payload = {"message": json.dumps(message)}
-            pipeline = self.REDIS.pipeline()
-            pipeline.xadd(queue, payload)
-            pipeline.expire(queue, exp)
-            pipeline.execute()
-            return True
-        except Exception as e:
-            logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
+        for _ in range(3):
+            try:
+                payload = {"message": json.dumps(message)}
+                pipeline = self.REDIS.pipeline()
+                pipeline.xadd(queue, payload)
+                pipeline.expire(queue, exp)
+                pipeline.execute()
+                return True
+            except Exception as e:
+                print(e)
+                logging.warning("[EXCEPTION]producer" + str(queue) + "||" + str(e))
         return False

     def queue_consumer(self, queue_name, group_name, consumer_name, msg_id=b">") -> Payload:
```
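The producer now retries up to three times before giving up, and callers treat a `False` return as fatal via `assert`. A standalone sketch of the same pattern with plain redis-py, assuming a local Redis; `produce` here is illustrative, not the project's API:

```python
import json
import redis  # redis-py, which RedisDB wraps

r = redis.Redis()

def produce(queue: str, message: dict, exp: int = 3600) -> bool:
    # Same shape as queue_product above: XADD the JSON payload and
    # refresh the stream's TTL in one pipelined round-trip, retrying
    # transient failures up to three times.
    for _ in range(3):
        try:
            pipe = r.pipeline()
            pipe.xadd(queue, {"message": json.dumps(message)})
            pipe.expire(queue, exp)
            pipe.execute()
            return True
        except Exception as e:
            print(e)
    return False

if __name__ == "__main__":
    print(produce("svr_queue_demo", {"id": "t1", "type": "raptor"}))
```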