Просмотр исходного кода

Optimize graphrag cache get entity (#6018)

### What problem does this PR solve?

Optimize graphrag cache get entity

### Type of change

- [x] Performance Improvement
tags/v0.17.2
Zhichang Yu 7 месяцев назад
Родитель
Сommit
e213873852
Аккаунт пользователя с таким Email не найден
2 измененных файлов: 58 добавлений и 24 удалений
  1. 27
    0
      graphrag/utils.py
  2. 31
    24
      rag/svr/task_executor.py

+ 27
- 0
graphrag/utils.py Просмотреть файл

@@ -237,8 +237,33 @@ def is_float_regex(value):
def chunk_id(chunk):
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()

def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
hasher = xxhash.xxh64()
hasher.update(str(tenant_id).encode("utf-8"))
hasher.update(str(kb_id).encode("utf-8"))
hasher.update(str(ent_name).encode("utf-8"))

k = hasher.hexdigest()
bin = REDIS_CONN.get(k)
if not bin:
return
return json.loads(bin)


def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
hasher = xxhash.xxh64()
hasher.update(str(tenant_id).encode("utf-8"))
hasher.update(str(kb_id).encode("utf-8"))
hasher.update(str(ent_name).encode("utf-8"))

k = hasher.hexdigest()
REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)


def get_entity(tenant_id, kb_id, ent_name):
cache = get_entity_cache(tenant_id, kb_id, ent_name)
if cache:
return cache
conds = {
"fields": ["content_with_weight"],
"entity_kwd": ent_name,
@@ -250,6 +275,7 @@ def get_entity(tenant_id, kb_id, ent_name):
for id in es_res.ids:
try:
if isinstance(ent_name, str):
set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
return json.loads(es_res.field[id]["content_with_weight"])
res.append(json.loads(es_res.field[id]["content_with_weight"]))
except Exception:
@@ -272,6 +298,7 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
if res.ids:

+ 31
- 24
rag/svr/task_executor.py Просмотреть файл

@@ -26,7 +26,6 @@ from rag.prompts import keyword_extraction, question_proposal, content_tagging

CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
initRootLogger(CONSUMER_NAME)

import logging
import os
@@ -43,6 +42,7 @@ import tracemalloc
import signal
import trio
import exceptiongroup
import faulthandler

import numpy as np
from peewee import DoesNotExist
@@ -139,30 +139,35 @@ class TaskCanceledException(Exception):


def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id)

if cancel:
msg += " [Canceled]"
prog = -1

if to_page > 0:
try:
if prog is not None and prog < 0:
msg = "[ERROR]" + msg
cancel = TaskService.do_cancel(task_id)

if cancel:
msg += " [Canceled]"
prog = -1

if to_page > 0:
if msg:
if from_page < to_page:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
if msg:
if from_page < to_page:
msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
if msg:
msg = datetime.now().strftime("%H:%M:%S") + " " + msg
d = {"progress_msg": msg}
if prog is not None:
d["progress"] = prog

logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
TaskService.update_progress(task_id, d)

close_connection()
if cancel:
raise TaskCanceledException(msg)
msg = datetime.now().strftime("%H:%M:%S") + " " + msg
d = {"progress_msg": msg}
if prog is not None:
d["progress"] = prog

TaskService.update_progress(task_id, d)

close_connection()
if cancel:
raise TaskCanceledException(msg)
logging.info(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}")
except DoesNotExist:
logging.warning(f"set_progress({task_id}) got exception DoesNotExist")
except Exception:
logging.exception(f"set_progress({task_id}), progress: {prog}, progress_msg: {msg}, got exception")

async def collect():
global CONSUMER_NAME, DONE_TASKS, FAILED_TASKS
@@ -664,4 +669,6 @@ async def main():
logging.error("BUG!!! You should not reach here!!!")

if __name__ == "__main__":
faulthandler.enable()
initRootLogger(CONSUMER_NAME)
trio.run(main)

Загрузка…
Отмена
Сохранить