# beartype_all(conf=BeartypeConf(violation_type=UserWarning))    # <-- emit warnings from all code
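"""Task-executor service: consumes document-processing tasks from Redis queues
under trio, reports worker heartbeats, and recovers tasks left pending by
defunct executors."""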

import logging
import os
import random
import signal
import sys
import threading
import time
from datetime import datetime

import trio
from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import run_graphrag
from graphrag.utils import chat_limiter
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.settings import DOC_MAXIMUM_SIZE, SVR_CONSUMER_GROUP_NAME, get_svr_queue_name, get_svr_queue_names, print_rag_settings, TAG_FLD, PAGERANK_FLD
from rag.utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.utils.storage_factory import STORAGE_IMPL

# Concurrency limits, overridable via environment variables.
# NOTE: MAX_CONCURRENT_TASKS is used below but its definition was missing here;
# the default of "5" is an assumption.
MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_CHUNK_BUILDERS = int(os.environ.get('MAX_CONCURRENT_CHUNK_BUILDERS', "1"))

task_limiter = trio.CapacityLimiter(MAX_CONCURRENT_TASKS)
chunk_limiter = trio.CapacityLimiter(MAX_CONCURRENT_CHUNK_BUILDERS)

# Seconds without a heartbeat before a worker is considered dead.
WORKER_HEARTBEAT_TIMEOUT = int(os.environ.get('WORKER_HEARTBEAT_TIMEOUT', '120'))

stop_event = threading.Event()
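

# Graceful-shutdown handler: set the stop flag so worker loops can wind down,
# give them a moment to notice, then exit the process.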
def signal_handler(sig, frame):
    logging.info("Received interrupt signal, shutting down...")
    stop_event.set()
    time.sleep(1)
    sys.exit(0)


# SIGUSR1 handler: start tracemalloc and take snapshot
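

# Heartbeat and housekeeping loop: every 30 seconds, trim this worker's stale
# heartbeats and, under a distributed lock, evict executors whose heartbeats
# have timed out.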
async def report_status():
    global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, DONE_TASKS, FAILED_TASKS
    REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
    redis_lock = RedisDistributedLock("clean_task_executor", lock_value=CONSUMER_NAME, timeout=60)
    while True:
        try:
            now = datetime.now()

            # (The heartbeat payload is assembled and written to the CONSUMER_NAME
            # sorted set at this point; that code is elided in this section.)

            # Expire heartbeats older than 30 minutes.
            expired = REDIS_CONN.zcount(CONSUMER_NAME, 0, now.timestamp() - 60 * 30)
            if expired > 0:
                REDIS_CONN.zpopmin(CONSUMER_NAME, expired)

            # Clean up task executors that have stopped heartbeating.
            if redis_lock.acquire():
                task_executors = REDIS_CONN.smembers("TASKEXE")
                for consumer_name in task_executors:
                    if consumer_name == CONSUMER_NAME:
                        continue
                    expired = REDIS_CONN.zcount(
                        consumer_name, now.timestamp() - WORKER_HEARTBEAT_TIMEOUT, now.timestamp() + 10
                    )
                    if expired == 0:
                        logging.info(f"{consumer_name} expired, removed")
                        REDIS_CONN.srem("TASKEXE", consumer_name)
                        REDIS_CONN.delete(consumer_name)
        except Exception:
            logging.exception("report_status got exception")
        await trio.sleep(30)
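

# Orphan-recovery loop (runs in a plain thread): under a distributed lock, scan
# each service queue for messages still pending on consumers that no longer
# exist and requeue them so a live worker can pick them up.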
def recover_pending_tasks():
    redis_lock = RedisDistributedLock("recover_pending_tasks", lock_value=CONSUMER_NAME, timeout=60)
    svr_queue_names = get_svr_queue_names()
    while not stop_event.is_set():
        try:
            if redis_lock.acquire():
                for queue_name in svr_queue_names:
                    msgs = REDIS_CONN.get_pending_msg(queue=queue_name, group_name=SVR_CONSUMER_GROUP_NAME)
                    # Ignore messages delivered to this worker itself.
                    msgs = [msg for msg in msgs if msg['consumer'] != CONSUMER_NAME]
                    if len(msgs) == 0:
                        continue

                    # Only requeue messages whose consumer is no longer registered.
                    task_executors = REDIS_CONN.smembers("TASKEXE")
                    task_executor_set = set(task_executors)
                    msgs = [msg for msg in msgs if msg['consumer'] not in task_executor_set]
                    for msg in msgs:
                        logging.info(
                            f"Recover pending task: {msg['message_id']}, consumer: {msg['consumer']}, "
                            f"time since delivered: {msg['time_since_delivered'] / 1000} s"
                        )
                        REDIS_CONN.requeue_msg(queue_name, SVR_CONSUMER_GROUP_NAME, msg['message_id'])

            stop_event.wait(60)
        except Exception:
            logging.exception("recover_pending_tasks got exception")
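

# Entry coroutine: log the startup banner, install signal handlers, start the
# recovery thread, then spawn one handle_task per free task_limiter slot.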
async def main():
    logging.info(r"""
  ______           __      ______                     __
    """)  # ASCII-art startup banner (truncated in the source of this section)

    if TRACE_MALLOC_ENABLED:
        start_tracemalloc_and_snapshot(None, None)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    threading.Thread(name="RecoverPendingTask", target=recover_pending_tasks).start()

    async with trio.open_nursery() as nursery:
        nursery.start_soon(report_status)
        while not stop_event.is_set():
            # Wait for a free concurrency slot before spawning another handler.
            async with task_limiter:
                nursery.start_soon(handle_task)
    logging.error("BUG!!! You should not reach here!!!")
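

# A minimal entrypoint sketch (assumed; the real module's startup code, such as
# consumer naming and settings printout, is not shown in this section): it
# initializes the root logger and drives main() under trio. CONSUMER_NAME and
# handle_task are defined elsewhere in the full module.
if __name__ == "__main__":
    initRootLogger(CONSUMER_NAME)
    trio.run(main)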