|
|
|
@@ -14,9 +14,10 @@ |
|
|
|
# limitations under the License. |
|
|
|
# |
|
|
|
import logging |
|
|
|
import inspect |
|
|
|
import sys |
|
|
|
from api.utils.log_utils import initRootLogger |
|
|
|
# Per-process consumer index, taken from argv so several task executors can
# run side by side; defaults to "0" when launched without an argument.
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]

# Initialize the root logger once, named after this consumer instance.
initRootLogger(f"task_executor_{CONSUMER_NO}")

# Quiet noisy third-party loggers.
for module in ["pdfminer"]:
    module_logger = logging.getLogger(module)
    module_logger.setLevel(logging.WARNING)

# Strip these libraries' own handlers and let their records propagate to the
# root handlers configured above.
for module in ["peewee"]:
    module_logger = logging.getLogger(module)
    module_logger.handlers.clear()
    module_logger.propagate = True
|
|
|
|
|
|
|
import datetime |
|
|
|
from datetime import datetime |
|
|
|
import json |
|
|
|
import os |
|
|
|
import hashlib |
|
|
|
@@ -33,7 +34,7 @@ import copy |
|
|
|
import re |
|
|
|
import sys |
|
|
|
import time |
|
|
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
import threading |
|
|
|
from functools import partial |
|
|
|
from io import BytesIO |
|
|
|
from multiprocessing.context import TimeoutError |
|
|
|
@@ -78,9 +79,14 @@ FACTORY = { |
|
|
|
ParserType.KG.value: knowledge_graph |
|
|
|
} |
|
|
|
|
|
|
|
# Stable identity of this worker in Redis (registration set and heartbeat key).
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
# Payload of the task currently being processed; None while idle.
PAYLOAD: Payload | None = None

# Status counters published by report_status()'s heartbeat.
BOOT_AT = datetime.now().isoformat()  # process start time, ISO 8601
DONE_TASKS = 0        # tasks completed since boot
RETRY_TASKS = 0       # tasks re-queued after a failure
PENDING_TASKS = 0     # current backlog length of the task queue
HEAD_CREATED_AT = ""  # ISO timestamp of the oldest queued task, "" if unknown
HEAD_DETAIL = ""      # raw payload of the oldest queued task, "" if unknown
|
|
|
|
|
|
|
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."): |
|
|
|
global PAYLOAD |
|
|
|
@@ -199,8 +205,8 @@ def build(row): |
|
|
|
md5.update((ck["content_with_weight"] + |
|
|
|
str(d["doc_id"])).encode("utf-8")) |
|
|
|
d["id"] = md5.hexdigest() |
|
|
|
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] |
|
|
|
d["create_timestamp_flt"] = datetime.datetime.now().timestamp() |
|
|
|
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] |
|
|
|
d["create_timestamp_flt"] = datetime.now().timestamp() |
|
|
|
if not d.get("image"): |
|
|
|
d["img_id"] = "" |
|
|
|
d["page_num_list"] = json.dumps([]) |
|
|
|
@@ -333,8 +339,8 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None): |
|
|
|
md5 = hashlib.md5() |
|
|
|
md5.update((content + str(d["doc_id"])).encode("utf-8")) |
|
|
|
d["id"] = md5.hexdigest() |
|
|
|
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] |
|
|
|
d["create_timestamp_flt"] = datetime.datetime.now().timestamp() |
|
|
|
d["create_time"] = str(datetime.now()).replace("T", " ")[:19] |
|
|
|
d["create_timestamp_flt"] = datetime.now().timestamp() |
|
|
|
d[vctr_nm] = vctr.tolist() |
|
|
|
d["content_with_weight"] = content |
|
|
|
d["content_ltks"] = rag_tokenizer.tokenize(content) |
|
|
|
@@ -403,7 +409,7 @@ def main(): |
|
|
|
|
|
|
|
logging.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st)) |
|
|
|
if es_r: |
|
|
|
callback(-1, f"Insert chunk error, detail info please check {LOG_FILE}. Please also check ES status!") |
|
|
|
callback(-1, "Insert chunk error, detail info please check log file. Please also check ES status!") |
|
|
|
docStoreConn.delete({"doc_id": r["doc_id"]}, search.index_name(r["tenant_id"]), r["kb_id"]) |
|
|
|
logging.error('Insert chunk error: ' + str(es_r)) |
|
|
|
else: |
|
|
|
@@ -420,24 +426,44 @@ def main(): |
|
|
|
|
|
|
|
|
|
|
|
def report_status():
    """Heartbeat loop: publish this worker's liveness and queue stats to Redis.

    Intended to run forever in a background (daemon) thread. Every 30 seconds:
      * keeps CONSUMER_NAME registered under the "TASKEXE" key (last 60 ticks),
      * samples the pending-queue length and the head task's creation time,
      * pushes a JSON heartbeat into a per-consumer sorted set scored by time,
      * trims heartbeats older than 30 minutes.
    Any exception is logged and swallowed so the loop never dies.
    """
    global CONSUMER_NAME, BOOT_AT, DONE_TASKS, RETRY_TASKS, PENDING_TASKS, HEAD_CREATED_AT, HEAD_DETAIL
    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
    while True:
        try:
            obj = REDIS_CONN.get("TASKEXE")
            obj = json.loads(obj) if obj else {}
            if CONSUMER_NAME not in obj:
                obj[CONSUMER_NAME] = []
            obj[CONSUMER_NAME].append(timer())
            # Keep only the 60 most recent ticks per consumer.
            obj[CONSUMER_NAME] = obj[CONSUMER_NAME][-60:]
            REDIS_CONN.set_obj("TASKEXE", obj, 60 * 2)

            now = datetime.now()
            PENDING_TASKS = REDIS_CONN.queue_length(SVR_QUEUE_NAME)
            if PENDING_TASKS > 0:
                head_info = REDIS_CONN.queue_head(SVR_QUEUE_NAME)
                if head_info is not None:
                    # Stream entry IDs look like "<epoch_ms>-<seq>"; the first
                    # field is the creation time in milliseconds.
                    seconds = int(head_info[0].split("-")[0]) / 1000
                    HEAD_CREATED_AT = datetime.fromtimestamp(seconds).isoformat()
                    HEAD_DETAIL = head_info[1]

            heartbeat = json.dumps({
                "name": CONSUMER_NAME,
                "now": now.isoformat(),
                "boot_at": BOOT_AT,
                "done": DONE_TASKS,
                "retry": RETRY_TASKS,
                "pending": PENDING_TASKS,
                "head_created_at": HEAD_CREATED_AT,
                "head_detail": HEAD_DETAIL,
            })
            REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
            logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")

            # Evict heartbeats older than 30 minutes.
            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
            if expired > 0:
                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)
        except Exception:
            logging.exception("report_status got exception")
        time.sleep(30)
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the heartbeat reporter in the background; daemon=True so the thread
    # cannot keep the process alive once the main loop exits.
    background_thread = threading.Thread(target=report_status)
    background_thread.daemon = True
    background_thread.start()

    # Consume and process tasks forever.
    while True:
        main()