### What problem does this PR solve?

Introduced task priority.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
     doc = doc.to_dict()
     doc["tenant_id"] = tenant_id
     bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-    queue_tasks(doc, bucket, name)
+    queue_tasks(doc, bucket, name, 0)
 except Exception as e:
     return server_error_response(e)

     doc = doc.to_dict()
     doc["tenant_id"] = tenant_id
     bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-    queue_tasks(doc, bucket, name)
+    queue_tasks(doc, bucket, name, 0)
     return get_json_result(data=True)
 except Exception as e:

     doc = doc.to_dict()
     doc["tenant_id"] = tenant_id
     bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-    queue_tasks(doc, bucket, name)
+    queue_tasks(doc, bucket, name, 0)
     return get_result()
```
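The call sites above now pass an explicit priority; `0` keeps the previous behaviour, since priority 0 maps to the original queue. A minimal sketch of the new call shape, with hypothetical stand-ins for the service layer:

```python
# Minimal sketch of the new call shape; queue_tasks and the values below are
# hypothetical stand-ins, not RAGFlow's actual service code.
def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
    print(f"queue {doc['id']} from {bucket}/{name} at priority {priority}")

doc = {"id": "doc-42", "tenant_id": "tenant-7"}   # stand-in for doc.to_dict()
bucket, name = "kb-bucket", "report.pdf"          # stand-in for get_storage_address()
queue_tasks(doc, bucket, name, 0)                 # explicit default priority
```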
```diff
     from_page = IntegerField(default=0)
     to_page = IntegerField(default=100000000)
     task_type = CharField(max_length=32, null=False, default="")
+    priority = IntegerField(default=0)
     begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)

         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "priority",
+                                IntegerField(default=0))
+        )
+    except Exception:
+        pass
```
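The new `priority` column is added both to the `Task` model and, for existing deployments, through a guarded peewee migration. A self-contained sketch of that migration pattern, assuming an in-memory SQLite database as a stand-in for the configured one:

```python
# Sketch of the add-column migration pattern used above; SQLite in memory stands
# in for the real database, and the try/except makes reruns a no-op.
from peewee import IntegerField, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase(":memory:")
db.execute_sql("CREATE TABLE task (id TEXT PRIMARY KEY)")  # pre-existing table

migrator = SqliteMigrator(db)
try:
    migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
except Exception:
    # Column already exists (e.g. on a second startup): ignore, as the PR does.
    pass

print([c.name for c in db.get_columns("task")])  # ['id', 'priority']
```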
```diff
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils import current_timestamp, get_format_time, get_uuid
 from rag.nlp import rag_tokenizer, search
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL

         has_graphrag = False
         e, doc = DocumentService.get_by_id(d["id"])
         status = doc.run  # TaskStatus.RUNNING.value
+        priority = 0
         for t in tsks:
             if 0 <= t.progress < 1:
                 finished = False

                 has_raptor = True
             elif t.task_type == "graphrag":
                 has_graphrag = True
+            priority = max(priority, t.priority)
         prg /= len(tsks)
         if finished and bad:
             prg = -1
             status = TaskStatus.FAIL.value
         elif finished:
             if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
-                queue_raptor_o_graphrag_tasks(d, "raptor")
+                queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                 prg = 0.98 * len(tsks) / (len(tsks) + 1)
             elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
-                queue_raptor_o_graphrag_tasks(d, "graphrag")
+                queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                 prg = 0.98 * len(tsks) / (len(tsks) + 1)
             else:
                 status = TaskStatus.DONE.value

         return False
```
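When a document's chunking tasks finish, the follow-up RAPTOR/GraphRAG task is queued at the highest priority seen among them, so a high-priority document keeps its priority through the whole pipeline. A small sketch of that propagation, with plain dicts standing in for the `Task` rows:

```python
# Plain dicts stand in for the Task rows inspected in the progress-update loop above.
tasks = [
    {"task_type": "", "priority": 0, "progress": 1.0},
    {"task_type": "", "priority": 1, "progress": 1.0},
]

priority = 0
for t in tasks:
    priority = max(priority, t["priority"])

print(priority)  # 1 -> the raptor/graphrag follow-up inherits the highest priority
```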
```diff
-def queue_raptor_o_graphrag_tasks(doc, ty):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):

     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
-    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
+    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."


 def doc_upload_and_parse(conversation_id, file_objs, user_id):
```
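For context on the surrounding lines: the task digest hashes the chunking configuration plus the task type so an equivalent raptor/graphrag task queued earlier can be recognised. A sketch of that hashing, assuming (the excerpt elides it) that each config value is fed to the hasher; values are illustrative only:

```python
# Requires the xxhash package; the loop body is an assumption since the hunk elides it.
import xxhash

chunking_config = {"parser_id": "naive", "parser_config": {"chunk_token_num": 128}}
ty = "raptor"

hasher = xxhash.xxh64()
for field in sorted(chunking_config.keys()):
    hasher.update(str(chunking_config[field]).encode("utf-8"))
hasher.update(ty.encode("utf-8"))

print(hasher.hexdigest())  # stable id for "this config + this task type"
```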
```diff
 from api.db.services.document_service import DocumentService
 from api.utils import current_timestamp, get_uuid
 from deepdoc.parser.excel_parser import RAGFlowExcelParser
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
 from api import settings

     ).execute()


-def queue_tasks(doc: dict, bucket: str, name: str):
+def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
     """Create and queue document processing tasks.

     This function creates processing tasks for a document based on its type and configuration.

         doc (dict): Document dictionary containing metadata and configuration.
         bucket (str): Storage bucket name where the document is stored.
         name (str): File name of the document.
+        priority (int, optional): Priority level for task queueing (default is 0).

     Note:
         - For PDF documents, tasks are created per page range based on configuration

         task_digest = hasher.hexdigest()
         task["digest"] = task_digest
         task["progress"] = 0.0
+        task["priority"] = priority

     prev_tasks = TaskService.get_tasks(doc["id"])
     ck_num = 0

         unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
         for unfinished_task in unfinished_task_array:
             assert REDIS_CONN.queue_product(
-                SVR_QUEUE_NAME, message=unfinished_task
+                get_svr_queue_name(priority), message=unfinished_task
             ), "Can't access Redis. Please check the Redis' status."
```
```diff
 DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
 SVR_QUEUE_NAME = "rag_flow_svr_queue"
-SVR_QUEUE_RETENTION = 60*60
-SVR_QUEUE_MAX_LEN = 1024
-SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
-SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
+SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 PAGERANK_FLD = "pagerank_fea"
 TAG_FLD = "tag_feas"

 def print_rag_settings():
     logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
-    logging.info(f"SERVER_QUEUE_MAX_LEN: {SVR_QUEUE_MAX_LEN}")
-    logging.info(f"SERVER_QUEUE_RETENTION: {SVR_QUEUE_RETENTION}")
     logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")


+def get_svr_queue_name(priority: int) -> str:
+    if priority == 0:
+        return SVR_QUEUE_NAME
+    return f"{SVR_QUEUE_NAME}_{priority}"
+
+
+def get_svr_queue_names():
+    return [get_svr_queue_name(priority) for priority in [1, 0]]
```
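The queue-name scheme is the core of the feature: priority 0 keeps the historical stream name, so existing deployments keep draining their old queue, while any other priority gets a suffixed stream, and `get_svr_queue_names()` lists them in the order consumers should poll. A standalone copy of the functions above (not an import of `rag.settings`), runnable as-is:

```python
# Standalone copy of the naming scheme introduced in this hunk.
SVR_QUEUE_NAME = "rag_flow_svr_queue"

def get_svr_queue_name(priority: int) -> str:
    if priority == 0:
        return SVR_QUEUE_NAME               # historical name: old deployments keep draining it
    return f"{SVR_QUEUE_NAME}_{priority}"

def get_svr_queue_names():
    return [get_svr_queue_name(priority) for priority in [1, 0]]  # poll high priority first

print(get_svr_queue_name(0))    # rag_flow_svr_queue
print(get_svr_queue_name(1))    # rag_flow_svr_queue_1
print(get_svr_queue_names())    # ['rag_flow_svr_queue_1', 'rag_flow_svr_queue']
```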
```diff
     email, tag
 from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
+from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL


 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
+    svr_queue_names = get_svr_queue_names()
     try:
         if not UNACKED_ITERATOR:
-            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
         try:
             redis_msg = next(UNACKED_ITERATOR)
         except StopIteration:
-            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
-            if not redis_msg:
-                await trio.sleep(1)
-                return None, None
+            for svr_queue_name in svr_queue_names:
+                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                if redis_msg:
+                    break
     except Exception:
         logging.exception("collect got exception")
         return None, None
+    if not redis_msg:
+        return None, None
     msg = redis_msg.get_message()
     if not msg:
         logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")

     while True:
         try:
             now = datetime.now()
-            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+            group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
             if group_info is not None:
                 PENDING_TASKS = int(group_info.get("pending", 0))
                 LAG_TASKS = int(group_info.get("lag", 0))
```
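`collect()` now asks each stream in priority order and takes the first message it gets, which is what actually makes priority-1 tasks jump the line. An in-memory sketch of that polling order, with deques standing in for the Redis streams and no consumer-group semantics:

```python
from collections import deque

# Deques stand in for the two Redis streams used above.
queues = {
    "rag_flow_svr_queue_1": deque(),   # priority 1
    "rag_flow_svr_queue": deque(),     # priority 0 (default)
}

def queue_consumer(queue_name):
    q = queues[queue_name]
    return q.popleft() if q else None

def collect_one():
    # Same shape as the reworked collect(): first queue that yields a message wins.
    for name in ["rag_flow_svr_queue_1", "rag_flow_svr_queue"]:
        msg = queue_consumer(name)
        if msg:
            return msg
    return None

queues["rag_flow_svr_queue"].append({"task_id": "t-low"})
queues["rag_flow_svr_queue_1"].append({"task_id": "t-high"})
print(collect_one())  # {'task_id': 't-high'} -- the priority stream is served first
print(collect_one())  # {'task_id': 't-low'}
```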
```diff
         self.__open__()
         return False

-    def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
+    def queue_product(self, queue, message) -> bool:
         for _ in range(3):
             try:
                 payload = {"message": json.dumps(message)}
-                pipeline = self.REDIS.pipeline()
-                pipeline.xadd(queue, payload)
-                # pipeline.expire(queue, exp)
-                pipeline.execute()
+                self.REDIS.xadd(queue, payload)
                 return True
             except Exception as e:
                 logging.exception(

         )
         return None
```
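`queue_product` drops the pipeline and the (already commented-out) stream expiry in favour of a single `XADD`; the stream is created implicitly on first write. A sketch of the equivalent call with the plain redis-py client, assuming a reachable local Redis:

```python
# Assumes a local Redis and the redis-py package; the payload mirrors the wrapper's
# {"message": json.dumps(message)} envelope.
import json
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
task = {"id": "task-1", "doc_id": "doc-42", "priority": 1}

msg_id = r.xadd("rag_flow_svr_queue_1", {"message": json.dumps(task)})
print(msg_id)  # e.g. "1718000000000-0"
```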
```diff
-    def get_unacked_iterator(self, queue_name, group_name, consumer_name):
+    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
         try:
-            group_info = self.REDIS.xinfo_groups(queue_name)
-            if not any(e["name"] == group_name for e in group_info):
-                return
-            current_min = 0
-            while True:
-                payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
-                if not payload:
-                    return
-                current_min = payload.get_msg_id()
-                logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
-                yield payload
+            for queue_name in queue_names:
+                group_info = self.REDIS.xinfo_groups(queue_name)
+                if not any(e["name"] == group_name for e in group_info):
+                    continue
+                current_min = 0
+                while True:
+                    payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
+                    if not payload:
+                        break
+                    current_min = payload.get_msg_id()
+                    logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
+                    yield payload
         except Exception as e:
             if "key" in str(e):
                 return
```
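`get_unacked_iterator` is now a generator over several streams: it walks one queue's unacknowledged messages to exhaustion (`break` instead of `return`) and then moves on to the next, so pending high-priority work is recovered first after a restart. A small in-memory sketch of that iteration order:

```python
# In-memory sketch only: a dict of {queue: [(msg_id, payload), ...]} replaces the
# Redis consumer-group bookkeeping the real method drives via queue_consumer().
def get_unacked_iterator(queue_names, pending):
    for queue_name in queue_names:
        for msg_id, payload in sorted(pending.get(queue_name, [])):
            yield queue_name, msg_id, payload   # finish one queue before the next

pending = {
    "rag_flow_svr_queue_1": [("1-0", {"task": "high"})],
    "rag_flow_svr_queue": [("2-0", {"task": "low"}), ("3-0", {"task": "low-2"})],
}

for item in get_unacked_iterator(["rag_flow_svr_queue_1", "rag_flow_svr_queue"], pending):
    print(item)
# ('rag_flow_svr_queue_1', '1-0', {'task': 'high'})
# ('rag_flow_svr_queue', '2-0', {'task': 'low'})
# ('rag_flow_svr_queue', '3-0', {'task': 'low-2'})
```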