
handle nits in task_executor (#2637)

### What problem does this PR solve?

- fix typo (`CONSUMEER_NAME` → `CONSUMER_NAME`)
- fix string formatting (drop `f` prefixes from literals with no placeholders)
- format imports (sort, group, and drop unused ones)

### Type of change

- [x] Refactoring
tags/v0.12.0
yqkcn 1 year ago
parent commit a44ed9626a
1 changed file with 36 additions and 38 deletions:

rag/svr/task_executor.py (+36, -38)
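Before the hunks, a condensed before/after sketch of the three nit categories this PR targets (illustrative lines, not the diff itself):

```python
# 1. Typo: a consistently misspelled module-level name.
CONSUMEER_NAME = "task_consumer_0"   # before
CONSUMER_NAME = "task_consumer_0"    # after

# 2. String format: f-strings that interpolate nothing; the real
#    formatting happened via % outside the literal (flake8 F541).
msg = f"Fetch file from minio timeout: %s" % "doc"   # before
msg = "Fetch file from minio timeout: %s" % "doc"    # after

# 3. Imports: regrouped stdlib / third-party / project and sorted;
#    unused names dropped (Pool, Search, findMaxTm).
```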

Imports (deduplicated, regrouped, and sorted):

```diff
 import traceback
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-
-from api.db.services.file2document_service import File2DocumentService
-from api.settings import retrievaler
-from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
-from rag.utils.storage_factory import STORAGE_IMPL
-from api.db.db_models import close_connection
-from rag.settings import database_logger, SVR_QUEUE_NAME
-from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
-from multiprocessing import Pool
-import numpy as np
-from elasticsearch_dsl import Q, Search
+from io import BytesIO
 from multiprocessing.context import TimeoutError
-from api.db.services.task_service import TaskService
-from rag.utils.es_conn import ELASTICSEARCH
 from timeit import default_timer as timer
-from rag.utils import rmSpace, findMaxTm, num_tokens_from_string
 
-from rag.nlp import search, rag_tokenizer
-from io import BytesIO
+import numpy as np
 import pandas as pd
-
-from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from elasticsearch_dsl import Q
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.task_service import TaskService
+from api.db.services.file2document_service import File2DocumentService
+from api.settings import retrievaler
 from api.utils.file_utils import get_project_base_directory
-from rag.utils.redis_conn import REDIS_CONN
+from api.db.db_models import close_connection
+from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, knowledge_graph, email
+from rag.nlp import search, rag_tokenizer
+from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
+from rag.settings import database_logger, SVR_QUEUE_NAME
+from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
+from rag.utils import rmSpace, num_tokens_from_string
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.redis_conn import REDIS_CONN, Payload
+from rag.utils.storage_factory import STORAGE_IMPL
 
 
 BATCH_SIZE = 64
```


Module-level consumer name and payload:

```diff
     ParserType.KG.value: knowledge_graph
 }


-CONSUMEER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
-PAYLOAD = None
+CONSUMER_NAME = "task_consumer_" + ("0" if len(sys.argv) < 2 else sys.argv[1])
+PAYLOAD: Payload | None = None
```
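One portability note on the new annotation: `Payload | None` is PEP 604 union syntax, valid at runtime only on Python 3.10+. On older interpreters the equivalent would be spelled like this (a sketch, not part of the diff):

```python
from typing import Optional

from rag.utils.redis_conn import Payload

PAYLOAD: Optional[Payload] = None   # same meaning as `Payload | None` before 3.10
```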



set_progress():

```diff
-def set_progress(task_id, from_page=0, to_page=-1,
-                 prog=None, msg="Processing..."):
+def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
     global PAYLOAD
     if prog is not None and prog < 0:
         msg = "[ERROR]" + msg
```




collect():

```diff
 def collect():
-    global CONSUMEER_NAME, PAYLOAD
+    global CONSUMER_NAME, PAYLOAD
     try:
-        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMEER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
+        PAYLOAD = REDIS_CONN.get_unacked_for(CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
         if not PAYLOAD:
-            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMEER_NAME)
+            PAYLOAD = REDIS_CONN.queue_consumer(SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
         if not PAYLOAD:
             time.sleep(1)
             return pd.DataFrame()
```
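The `collect()` changes are pure renames, but the two calls read as a deliberate recover-then-consume pattern: judging by the names, `get_unacked_for` re-delivers a message this consumer received but never acknowledged (e.g. the process died mid-task), and only then is fresh work pulled from the queue. Schematically, using just the calls visible in this hunk (hypothetical wrapper, not from the file):

```python
def fetch_next_payload():
    # First, reclaim any message this consumer left unacknowledged...
    payload = REDIS_CONN.get_unacked_for(
        CONSUMER_NAME, SVR_QUEUE_NAME, "rag_flow_svr_task_broker")
    # ...and only if there is none, consume a new message from the queue.
    if not payload:
        payload = REDIS_CONN.queue_consumer(
            SVR_QUEUE_NAME, "rag_flow_svr_task_broker", CONSUMER_NAME)
    return payload  # may be empty if the queue is idle
```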
Fetching the source file from MinIO:

```diff
         binary = get_storage_binary(bucket, name)
         cron_logger.info(
             "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
-    except TimeoutError as e:
-        callback(-1, f"Internal server error: Fetch file from minio timeout. Could you try it again.")
+    except TimeoutError:
+        callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
         cron_logger.error(
             "Minio {}/{}: Fetch file from minio timeout.".format(row["location"], row["name"]))
         return
     except Exception as e:
         if re.search("(No such file|not found)", str(e)):
             callback(-1, "Can not find file <%s> from minio. Could you try it again?" % row["name"])
         else:
-            callback(-1, f"Get file from minio: %s" %
-                     str(e).replace("'", ""))
+            callback(-1, "Get file from minio: %s" % str(e).replace("'", ""))
         traceback.print_exc()
         return
```
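Two nits meet in this hunk: the unused `as e` binding is dropped from the `TimeoutError` handler, and the `f` prefix is removed from literals that contain no `{}` placeholders (flake8's F541); the actual interpolation was always done by the `%` operator outside the literal. A minimal illustration:

```python
name = "report.pdf"
# Before: the f-prefix does nothing, since there is no {placeholder} inside.
msg = f"Get file from minio: %s" % name
# After: identical output, honest literal.
msg = "Get file from minio: %s" % name
assert msg == "Get file from minio: report.pdf"
```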


Chunking error handling:

```diff
         cron_logger.info(
             "Chunking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except Exception as e:
-        callback(-1, f"Internal server error while chunking: %s" %
+        callback(-1, "Internal server error while chunking: %s" %
                  str(e).replace("'", ""))
         cron_logger.error(
             "Chunking {}/{}: {}".format(row["location"], row["name"], str(e)))
```
embedding():

```diff
         open(os.path.join(get_project_base_directory(), "conf", "mapping.json"), "r")))


-def embedding(docs, mdl, parser_config={}, callback=None):
+def embedding(docs, mdl, parser_config=None, callback=None):
+    if parser_config is None:
+        parser_config = {}
     batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", d["content_with_weight"]) for d in docs]
```
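The `parser_config={}` to `parser_config=None` change fixes Python's classic mutable-default pitfall: the default dict is created once, when `def` executes, and is then shared by every call that omits the argument. A standalone demonstration with a hypothetical function (not from this file):

```python
def buggy(item, seen={}):      # one dict created at definition time
    seen[item] = True
    return seen

print(buggy("a"))  # {'a': True}
print(buggy("b"))  # {'a': True, 'b': True}  <- state leaked across calls

def fixed(item, seen=None):
    if seen is None:           # a fresh dict on every call, as the PR does
        seen = {}
    seen[item] = True
    return seen

print(fixed("a"))  # {'a': True}
print(fixed("b"))  # {'b': True}
```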


run_raptor():

```diff
 def run_raptor(row, chat_mdl, embd_mdl, callback=None):
     vts, _ = embd_mdl.encode(["ok"])
-    vctr_nm = "q_%d_vec"%len(vts[0])
+    vctr_nm = "q_%d_vec" % len(vts[0])
     chunks = []
     for d in retrievaler.chunk_list(row["doc_id"], row["tenant_id"], fields=["content_with_weight", vctr_nm]):
         chunks.append((d["content_with_weight"], np.array(d[vctr_nm])))
```
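Besides the PEP 8 spacing around `%`, the touched line is worth a gloss: the vector field name encodes the embedding dimension, probed by embedding the literal string "ok". So, assuming a 1024-dimensional model:

```python
vts = [[0.1] * 1024]                 # stand-in for embd_mdl.encode(["ok"])[0]
vctr_nm = "q_%d_vec" % len(vts[0])   # dimension-specific field name
assert vctr_nm == "q_1024_vec"
```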


cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
if es_r: if es_r:
callback(-1, f"Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
callback(-1, "Insert chunk error, detail info please check ragflow-logs/api/cron_logger.log. Please also check ES status!")
ELASTICSEARCH.deleteByQuery( ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"])) Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
cron_logger.error(str(es_r)) cron_logger.error(str(es_r))




report_status():

```diff
 def report_status():
-    global CONSUMEER_NAME
+    global CONSUMER_NAME
     while True:
         try:
             obj = REDIS_CONN.get("TASKEXE")
             if not obj: obj = {}
             else: obj = json.loads(obj)
-            if CONSUMEER_NAME not in obj: obj[CONSUMEER_NAME] = []
-            obj[CONSUMEER_NAME].append(timer())
-            obj[CONSUMEER_NAME] = obj[CONSUMEER_NAME][-60:]
+            if CONSUMER_NAME not in obj: obj[CONSUMER_NAME] = []
+            obj[CONSUMER_NAME].append(timer())
+            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
             REDIS_CONN.set_obj("TASKEXE", obj, 60*2)
         except Exception as e:
             print("[Exception]:", str(e))
```
