### What problem does this PR solve?

Introduced task priority.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
```diff
@@ -479,7 +479,7 @@ def upload():
             doc = doc.to_dict()
             doc["tenant_id"] = tenant_id
             bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-            queue_tasks(doc, bucket, name)
+            queue_tasks(doc, bucket, name, 0)
     except Exception as e:
         return server_error_response(e)
@@ -380,7 +380,7 @@ def run():
             doc = doc.to_dict()
             doc["tenant_id"] = tenant_id
             bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-            queue_tasks(doc, bucket, name)
+            queue_tasks(doc, bucket, name, 0)
         return get_json_result(data=True)
     except Exception as e:
@@ -693,7 +693,7 @@ def parse(tenant_id, dataset_id):
         doc = doc.to_dict()
         doc["tenant_id"] = tenant_id
         bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
-        queue_tasks(doc, bucket, name)
+        queue_tasks(doc, bucket, name, 0)
     return get_result()
```
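All existing call sites pass `0`, so uploads keep landing on the default-priority queue. Purely as an illustration (not part of this PR), a caller could surface the new argument instead of hard-coding it; the helper below and its import paths are assumptions based on the surrounding code:

```python
# Hypothetical caller sketch, not part of this PR: expose the new priority argument
# instead of hard-coding 0. With get_svr_queue_names() introduced below, only
# priorities 0 and 1 map to queues the task executor actually polls.
from api.db.services.file2document_service import File2DocumentService
from api.db.services.task_service import queue_tasks


def queue_with_priority(doc: dict, priority: int = 0) -> None:
    bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
    queue_tasks(doc, bucket, name, priority)
```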
```diff
@@ -845,6 +845,7 @@ class Task(DataBaseModel):
     from_page = IntegerField(default=0)
     to_page = IntegerField(default=100000000)
     task_type = CharField(max_length=32, null=False, default="")
+    priority = IntegerField(default=0)
     begin_at = DateTimeField(null=True, index=True)
     process_duation = FloatField(default=0)
@@ -1122,3 +1123,10 @@ def migrate_db():
         )
     except Exception:
         pass
+    try:
+        migrate(
+            migrator.add_column("task", "priority",
+                                IntegerField(default=0))
+        )
+    except Exception:
+        pass
```
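The schema change comes in two parts: new installs get the column from the model definition, while existing databases pick it up through `migrate_db()`. A minimal, self-contained sketch of that migration pattern, assuming a plain SQLite database (RAGFlow builds the migrator for its configured backend):

```python
from peewee import IntegerField, SqliteDatabase
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteDatabase("ragflow_demo.db")
migrator = SqliteMigrator(db)

try:
    # Add the new column with its default; existing rows implicitly get priority = 0.
    migrate(migrator.add_column("task", "priority", IntegerField(default=0)))
except Exception:
    # Mirror migrate_db(): swallow the error when the column already exists,
    # so the migration stays idempotent across restarts.
    pass
```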
```diff
@@ -34,7 +34,7 @@ from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.utils import current_timestamp, get_format_time, get_uuid
 from rag.nlp import rag_tokenizer, search
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
@@ -392,6 +392,7 @@ class DocumentService(CommonService):
             has_graphrag = False
             e, doc = DocumentService.get_by_id(d["id"])
             status = doc.run  # TaskStatus.RUNNING.value
+            priority = 0
             for t in tsks:
                 if 0 <= t.progress < 1:
                     finished = False
@@ -403,16 +404,17 @@ class DocumentService(CommonService):
                     has_raptor = True
                 elif t.task_type == "graphrag":
                     has_graphrag = True
+                priority = max(priority, t.priority)
             prg /= len(tsks)
             if finished and bad:
                 prg = -1
                 status = TaskStatus.FAIL.value
             elif finished:
                 if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
-                    queue_raptor_o_graphrag_tasks(d, "raptor")
+                    queue_raptor_o_graphrag_tasks(d, "raptor", priority)
                     prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
-                    queue_raptor_o_graphrag_tasks(d, "graphrag")
+                    queue_raptor_o_graphrag_tasks(d, "graphrag", priority)
                     prg = 0.98 * len(tsks) / (len(tsks) + 1)
                 else:
                     status = TaskStatus.DONE.value
```
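`update_progress` now tracks the highest priority among a document's existing tasks and hands it to the follow-up raptor/graphrag tasks, so post-processing inherits the urgency of the original parse. The rule in isolation:

```python
from dataclasses import dataclass


@dataclass
class TaskRow:  # stand-in for the peewee Task rows iterated in update_progress
    task_type: str
    priority: int


def inherit_priority(tasks: list[TaskRow]) -> int:
    """Follow-up tasks take the maximum priority seen so far, defaulting to 0."""
    priority = 0
    for t in tasks:
        priority = max(priority, t.priority)
    return priority


assert inherit_priority([TaskRow("standard", 0), TaskRow("standard", 1)]) == 1
assert inherit_priority([]) == 0
```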
```diff
@@ -449,7 +451,7 @@ class DocumentService(CommonService):
         return False


-def queue_raptor_o_graphrag_tasks(doc, ty):
+def queue_raptor_o_graphrag_tasks(doc, ty, priority):
     chunking_config = DocumentService.get_chunking_config(doc["id"])
     hasher = xxhash.xxh64()
     for field in sorted(chunking_config.keys()):
@@ -472,7 +474,7 @@ def queue_raptor_o_graphrag_tasks(doc, ty):
     hasher.update(ty.encode("utf-8"))
     task["digest"] = hasher.hexdigest()
     bulk_insert_into_db(Task, [task], True)
-    assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
+    assert REDIS_CONN.queue_product(get_svr_queue_name(priority), message=task), "Can't access Redis. Please check the Redis' status."


 def doc_upload_and_parse(conversation_id, file_objs, user_id):
```
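`queue_raptor_o_graphrag_tasks` keeps its digest scheme (hash the sorted chunking config plus the task type) and only changes where the message is published. A standalone sketch of that digest; the per-field serialization is an assumption, since the hunk above elides it:

```python
import xxhash


def task_digest(chunking_config: dict, task_type: str) -> str:
    hasher = xxhash.xxh64()
    for field in sorted(chunking_config.keys()):
        # Assumed serialization: the diff only shows the loop over sorted keys.
        hasher.update(str(chunking_config[field]).encode("utf-8"))
    hasher.update(task_type.encode("utf-8"))
    return hasher.hexdigest()


print(task_digest({"chunk_token_num": 128, "layout_recognize": True}, "raptor"))
```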
```diff
@@ -28,7 +28,7 @@ from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.utils import current_timestamp, get_uuid
 from deepdoc.parser.excel_parser import RAGFlowExcelParser
-from rag.settings import SVR_QUEUE_NAME
+from rag.settings import get_svr_queue_name
 from rag.utils.storage_factory import STORAGE_IMPL
 from rag.utils.redis_conn import REDIS_CONN
 from api import settings
@@ -289,7 +289,7 @@ class TaskService(CommonService):
         ).execute()


-def queue_tasks(doc: dict, bucket: str, name: str):
+def queue_tasks(doc: dict, bucket: str, name: str, priority: int):
     """Create and queue document processing tasks.

     This function creates processing tasks for a document based on its type and configuration.
@@ -301,6 +301,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         doc (dict): Document dictionary containing metadata and configuration.
         bucket (str): Storage bucket name where the document is stored.
         name (str): File name of the document.
+        priority (int, optional): Priority level for task queueing (default is 0).

     Note:
         - For PDF documents, tasks are created per page range based on configuration
@@ -358,6 +359,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
         task_digest = hasher.hexdigest()
         task["digest"] = task_digest
         task["progress"] = 0.0
+        task["priority"] = priority

     prev_tasks = TaskService.get_tasks(doc["id"])
     ck_num = 0
@@ -380,7 +382,7 @@ def queue_tasks(doc: dict, bucket: str, name: str):
     unfinished_task_array = [task for task in parse_task_array if task["progress"] < 1.0]
     for unfinished_task in unfinished_task_array:
         assert REDIS_CONN.queue_product(
-            SVR_QUEUE_NAME, message=unfinished_task
+            get_svr_queue_name(priority), message=unfinished_task
         ), "Can't access Redis. Please check the Redis' status."
```
```diff
@@ -35,16 +35,20 @@ except Exception:
 DOC_MAXIMUM_SIZE = int(os.environ.get("MAX_CONTENT_LENGTH", 128 * 1024 * 1024))
 SVR_QUEUE_NAME = "rag_flow_svr_queue"
-SVR_QUEUE_RETENTION = 60*60
-SVR_QUEUE_MAX_LEN = 1024
-SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
-SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
+SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_task_broker"
 PAGERANK_FLD = "pagerank_fea"
 TAG_FLD = "tag_feas"


 def print_rag_settings():
     logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
-    logging.info(f"SERVER_QUEUE_MAX_LEN: {SVR_QUEUE_MAX_LEN}")
-    logging.info(f"SERVER_QUEUE_RETENTION: {SVR_QUEUE_RETENTION}")
     logging.info(f"MAX_FILE_COUNT_PER_USER: {int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))}")
+
+
+def get_svr_queue_name(priority: int) -> str:
+    if priority == 0:
+        return SVR_QUEUE_NAME
+    return f"{SVR_QUEUE_NAME}_{priority}"
+
+
+def get_svr_queue_names():
+    return [get_svr_queue_name(priority) for priority in [1, 0]]
```
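The naming scheme is small enough to show standalone: priority 0 keeps the original stream name, any other priority gets its own suffixed stream, and `get_svr_queue_names()` lists the streams in the order consumers should poll them (high priority first):

```python
SVR_QUEUE_NAME = "rag_flow_svr_queue"


def get_svr_queue_name(priority: int) -> str:
    if priority == 0:
        return SVR_QUEUE_NAME
    return f"{SVR_QUEUE_NAME}_{priority}"


def get_svr_queue_names():
    return [get_svr_queue_name(priority) for priority in [1, 0]]


assert get_svr_queue_name(0) == "rag_flow_svr_queue"
assert get_svr_queue_name(1) == "rag_flow_svr_queue_1"
assert get_svr_queue_names() == ["rag_flow_svr_queue_1", "rag_flow_svr_queue"]
```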
```diff
@@ -56,7 +56,7 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume,
     email, tag
 from rag.nlp import search, rag_tokenizer
 from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
+from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
 from rag.utils import num_tokens_from_string
 from rag.utils.redis_conn import REDIS_CONN
 from rag.utils.storage_factory import STORAGE_IMPL
@@ -171,20 +171,23 @@ def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing...
 async def collect():
     global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
     global UNACKED_ITERATOR
+    svr_queue_names = get_svr_queue_names()
     try:
         if not UNACKED_ITERATOR:
-            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
+            UNACKED_ITERATOR = REDIS_CONN.get_unacked_iterator(svr_queue_names, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
         try:
             redis_msg = next(UNACKED_ITERATOR)
         except StopIteration:
-            redis_msg = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
-            if not redis_msg:
-                await trio.sleep(1)
-                return None, None
+            for svr_queue_name in svr_queue_names:
+                redis_msg = REDIS_CONN.queue_consumer(svr_queue_name, SVR_CONSUMER_GROUP_NAME, CONSUMER_NAME)
+                if redis_msg:
+                    break
     except Exception:
         logging.exception("collect got exception")
         return None, None
+
+    if not redis_msg:
+        return None, None
     msg = redis_msg.get_message()
     if not msg:
         logging.error(f"collect got empty message of {redis_msg.get_msg_id()}")
```
```diff
@@ -615,7 +618,7 @@ async def report_status():
     while True:
         try:
             now = datetime.now()
-            group_info = REDIS_CONN.queue_info(SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+            group_info = REDIS_CONN.queue_info(get_svr_queue_name(0), SVR_CONSUMER_GROUP_NAME)
             if group_info is not None:
                 PENDING_TASKS = int(group_info.get("pending", 0))
                 LAG_TASKS = int(group_info.get("lag", 0))
```
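`report_status()` keeps reporting `pending`/`lag` for the default-priority stream only. `queue_info()` is not shown in this diff; a hedged sketch of the kind of lookup it likely performs, using redis-py's `XINFO GROUPS` (the helper name and the assumption of a client created with `decode_responses=True` are mine):

```python
import redis


def group_stats(r: redis.Redis, stream: str, group: str) -> tuple[int, int]:
    """Return (pending, lag) for one consumer group; lag requires Redis 7+."""
    for info in r.xinfo_groups(stream):
        if info["name"] == group:
            return int(info.get("pending", 0)), int(info.get("lag", 0) or 0)
    return 0, 0


# Usage sketch:
# r = redis.Redis(decode_responses=True)
# pending, lag = group_stats(r, "rag_flow_svr_queue", "rag_flow_svr_task_broker")
```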
```diff
@@ -193,14 +193,11 @@ class RedisDB:
                 self.__open__()
         return False

-    def queue_product(self, queue, message, exp=settings.SVR_QUEUE_RETENTION) -> bool:
+    def queue_product(self, queue, message) -> bool:
         for _ in range(3):
             try:
                 payload = {"message": json.dumps(message)}
-                pipeline = self.REDIS.pipeline()
-                pipeline.xadd(queue, payload)
-                # pipeline.expire(queue, exp)
-                pipeline.execute()
+                self.REDIS.xadd(queue, payload)
                 return True
             except Exception as e:
                 logging.exception(
```
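The producer side loses the pipeline and the commented-out per-queue expiry; it is now a bare `XADD` wrapped in a small retry loop. The same shape, standalone with redis-py:

```python
import json
import logging

import redis


def queue_product(r: redis.Redis, queue: str, message: dict) -> bool:
    """Publish one task message to a stream, retrying a few times before giving up."""
    for _ in range(3):
        try:
            r.xadd(queue, {"message": json.dumps(message)})
            return True
        except Exception:
            logging.exception("queue_product %s got exception", queue)
    return False


# Usage sketch (queue name from get_svr_queue_name(1)):
# queue_product(redis.Redis(), "rag_flow_svr_queue_1", {"id": "task-1", "priority": 1})
```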
```diff
@@ -242,19 +239,20 @@ class RedisDB:
                 )
         return None

-    def get_unacked_iterator(self, queue_name, group_name, consumer_name):
+    def get_unacked_iterator(self, queue_names: list[str], group_name, consumer_name):
         try:
-            group_info = self.REDIS.xinfo_groups(queue_name)
-            if not any(e["name"] == group_name for e in group_info):
-                return
-            current_min = 0
-            while True:
-                payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
-                if not payload:
-                    return
-                current_min = payload.get_msg_id()
-                logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
-                yield payload
+            for queue_name in queue_names:
+                group_info = self.REDIS.xinfo_groups(queue_name)
+                if not any(e["name"] == group_name for e in group_info):
+                    continue
+                current_min = 0
+                while True:
+                    payload = self.queue_consumer(queue_name, group_name, consumer_name, current_min)
+                    if not payload:
+                        break
+                    current_min = payload.get_msg_id()
+                    logging.info(f"RedisDB.get_unacked_iterator {consumer_name} msg_id {current_min}")
+                    yield payload
         except Exception as e:
             if "key" in str(e):
                 return
```
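`get_unacked_iterator` now walks every priority queue, skipping streams where the consumer group does not exist yet and replaying any messages this consumer received but never acknowledged. `queue_consumer` itself is not shown in the diff; the sketch below assumes it wraps `XREADGROUP` with an explicit ID (which returns a consumer's pending entries rather than new ones) and a client created with `decode_responses=True`:

```python
import redis


def iter_unacked(r: redis.Redis, queue_names: list[str], group: str, consumer: str):
    for queue in queue_names:
        try:
            if not any(g["name"] == group for g in r.xinfo_groups(queue)):
                continue  # group never created on this stream, nothing to replay
        except redis.ResponseError:
            continue  # stream does not exist yet
        last_id = 0
        while True:
            # An explicit ID (not ">") asks for this consumer's pending entries after last_id.
            resp = r.xreadgroup(group, consumer, {queue: last_id}, count=1)
            if not resp or not resp[0][1]:
                break  # this queue is drained; move on to the next priority
            msg_id, fields = resp[0][1][0]
            last_id = msg_id
            yield queue, msg_id, fields
```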