Browse Source

Refactor graphrag to remove redis lock (#5828)

### What problem does this PR solve?

Refactor graphrag to remove redis lock

### Type of change

- [x] Refactoring
tags/v0.17.1
Zhichang Yu 7 months ago
parent
commit
6ec6ca6971
No account linked to committer's email address

+ 6
- 0
api/ragflow_server.py View File

from api.versions import get_ragflow_version from api.versions import get_ragflow_version
from api.utils import show_configs from api.utils import show_configs
from rag.settings import print_rag_settings from rag.settings import print_rag_settings
from rag.utils.redis_conn import RedisDistributedLock


stop_event = threading.Event() stop_event = threading.Event()


def update_progress(): def update_progress():
redis_lock = RedisDistributedLock("update_progress", timeout=60)
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
if not redis_lock.acquire():
continue
DocumentService.update_progress() DocumentService.update_progress()
stop_event.wait(6) stop_event.wait(6)
except Exception: except Exception:
logging.exception("update_progress exception") logging.exception("update_progress exception")
finally:
redis_lock.release()


def signal_handler(sig, frame): def signal_handler(sig, frame):
logging.info("Received interrupt signal, shutting down...") logging.info("Received interrupt signal, shutting down...")

+ 7
- 4
graphrag/general/extractor.py View File

return dict(maybe_nodes), dict(maybe_edges) return dict(maybe_nodes), dict(maybe_edges)


async def __call__( async def __call__(
self, chunks: list[tuple[str, str]],
self, doc_id: str, chunks: list[str],
callback: Callable | None = None callback: Callable | None = None
): ):


start_ts = trio.current_time() start_ts = trio.current_time()
out_results = [] out_results = []
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for i, (cid, ck) in enumerate(chunks):
for i, ck in enumerate(chunks):
ck = truncate(ck, int(self._llm.max_length*0.8)) ck = truncate(ck, int(self._llm.max_length*0.8))
nursery.start_soon(lambda: self._process_single_content((cid, ck), i, len(chunks), out_results))
nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))


maybe_nodes = defaultdict(list) maybe_nodes = defaultdict(list)
maybe_edges = defaultdict(list) maybe_edges = defaultdict(list)
) -> str: ) -> str:
summary_max_tokens = 512 summary_max_tokens = 512
use_description = truncate(description, summary_max_tokens) use_description = truncate(description, summary_max_tokens)
description_list=use_description.split(GRAPH_FIELD_SEP),
if len(description_list) <= 12:
return use_description
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
context_base = dict( context_base = dict(
entity_name=entity_or_relation_name, entity_name=entity_or_relation_name,
description_list=use_description.split(GRAPH_FIELD_SEP),
description_list=description_list,
language=self._language, language=self._language,
) )
use_prompt = prompt_template.format(**context_base) use_prompt = prompt_template.format(**context_base)

+ 338
- 181
graphrag/general/index.py View File

# #
import json import json
import logging import logging
from functools import reduce, partial
from functools import partial
import networkx as nx import networkx as nx
import trio import trio


from api import settings from api import settings
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.community_reports_extractor import CommunityReportsExtractor from graphrag.general.community_reports_extractor import CommunityReportsExtractor
from graphrag.entity_resolution import EntityResolution from graphrag.entity_resolution import EntityResolution
from graphrag.general.extractor import Extractor from graphrag.general.extractor import Extractor
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
chunk_id, update_nodes_pagerank_nhop_neighbour
from graphrag.utils import (
graph_merge,
set_entity,
get_relation,
set_relation,
get_entity,
get_graph,
set_graph,
chunk_id,
update_nodes_pagerank_nhop_neighbour,
does_graph_contains,
get_graph_doc_ids,
)
from rag.nlp import rag_tokenizer, search from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import RedisDistributedLock


class Dealer:
def __init__(self,
extractor: Extractor,
tenant_id: str,
kb_id: str,
llm_bdl,
chunks: list[tuple[str, str]],
language,
entity_types=DEFAULT_ENTITY_TYPES,
embed_bdl=None,
callback=None
):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.chunks = chunks
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
self.ext = extractor(self.llm_bdl, language=language,
entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
)
self.graph = nx.Graph()
self.callback = callback

async def __call__(self):
docids = list(set([docid for docid, _ in self.chunks]))
ents, rels = await self.ext(self.chunks, self.callback)
for en in ents:
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])

for rel in rels:
self.graph.add_edge(
rel["src_id"],
rel["tgt_id"],
weight=rel["weight"],
#description=rel["description"]
from rag.utils.redis_conn import REDIS_CONN


def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
key = f"graphrag:{tenant_id}:{kb_id}"
ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
if not ok:
raise Exception(f"Faild to set the {key} to {doc_id}")


def graphrag_task_get(tenant_id, kb_id) -> str | None:
key = f"graphrag:{tenant_id}:{kb_id}"
doc_id = REDIS_CONN.get(key)
return doc_id


async def run_graphrag(
row: dict,
language,
with_resolution: bool,
with_community: bool,
chat_model,
embedding_model,
callback,
):
start = trio.current_time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retrievaler.chunk_list(
doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
):
chunks.append(d["content_with_weight"])

graph, doc_ids = await update_graph(
LightKGExt
if row["parser_config"]["graphrag"]["method"] != "general"
else GeneralKGExt,
tenant_id,
kb_id,
doc_id,
chunks,
language,
row["parser_config"]["graphrag"]["entity_types"],
chat_model,
embedding_model,
callback,
)
if not graph:
return
if with_resolution or with_community:
graphrag_task_set(tenant_id, kb_id, doc_id)
if with_resolution:
await resolve_entities(
graph,
doc_ids,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
)
if with_community:
await extract_community(
graph,
doc_ids,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
)
now = trio.current_time()
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
return


async def update_graph(
extractor: Extractor,
tenant_id: str,
kb_id: str,
doc_id: str,
chunks: list[str],
language,
entity_types,
llm_bdl,
embed_bdl,
callback,
):
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
if contains:
callback(msg=f"Graph already contains {doc_id}, cancel myself")
return None, None
start = trio.current_time()
ext = extractor(
llm_bdl,
language=language,
entity_types=entity_types,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
ents, rels = await ext(doc_id, chunks, callback)
subgraph = nx.Graph()
for en in ents:
subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])

for rel in rels:
subgraph.add_edge(
rel["src_id"],
rel["tgt_id"],
weight=rel["weight"],
# description=rel["description"]
)
# TODO: infinity doesn't support array search
chunk = {
"content_with_weight": json.dumps(
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
),
"knowledge_graph_kwd": "subgraph",
"kb_id": kb_id,
"source_id": [doc_id],
"available_int": 0,
"removed_kwd": "N",
}
cid = chunk_id(chunk)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.insert(
[{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
)
)
now = trio.current_time()
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
start = now

while True:
new_graph = subgraph
now_docids = set([doc_id])
old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
new_graph = graph_merge(old_graph, subgraph)
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
if old_doc_ids:
for old_doc_id in old_doc_ids:
now_docids.add(old_doc_id)
old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
if delta_doc_ids:
callback(
msg="The global graph has changed during merging, try again"
) )
await trio.sleep(1)
continue
break
await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
now = trio.current_time()
callback(
msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
)
return new_graph, now_docids


async def resolve_entities(
graph,
doc_ids,
tenant_id: str,
kb_id: str,
doc_id: str,
llm_bdl,
embed_bdl,
callback,
):
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
start = trio.current_time()
er = EntityResolution(
llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
reso = await er(graph)
graph = reso.graph
callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
callback(msg="Graph resolution updated pagerank.")

working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
await set_graph(tenant_id, kb_id, graph, doc_ids)


with RedisDistributedLock(self.kb_id, 60*60):
old_graph, old_doc_ids = get_graph(self.tenant_id, self.kb_id)
if old_graph is not None:
logging.info("Merge with an exiting graph...................")
self.graph = reduce(graph_merge, [old_graph, self.graph])
update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2)
if old_doc_ids:
docids.extend(old_doc_ids)
docids = list(set(docids))
set_graph(self.tenant_id, self.kb_id, self.graph, docids)


class WithResolution(Dealer):
def __init__(self,
tenant_id: str,
kb_id: str,
llm_bdl,
embed_bdl=None,
callback=None
):
self.tenant_id = tenant_id
self.kb_id = kb_id
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = await trio.to_thread.run_sync(lambda: get_graph(self.tenant_id, self.kb_id))
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return

if self.callback:
self.callback(msg="Fetch the existing graph.")
er = EntityResolution(self.llm_bdl,
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
reso = await er(self.graph)
self.graph = reso.graph
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
if self.callback:
self.callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
await trio.to_thread.run_sync(lambda: update_nodes_pagerank_nhop_neighbour(self.tenant_id, self.kb_id, self.graph, 2))
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": self.kb_id,
"from_entity_kwd": reso.removed_entities
}, search.index_name(self.tenant_id), self.kb_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "relation",
"kb_id": self.kb_id,
"to_entity_kwd": reso.removed_entities
}, search.index_name(self.tenant_id), self.kb_id))
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
"knowledge_graph_kwd": "entity",
"kb_id": self.kb_id,
"entity_kwd": reso.removed_entities
}, search.index_name(self.tenant_id), self.kb_id))


class WithCommunity(Dealer):
def __init__(self,
tenant_id: str,
kb_id: str,
llm_bdl,
embed_bdl=None,
callback=None
):

self.tenant_id = tenant_id
self.kb_id = kb_id
self.community_structure = None
self.community_reports = None
self.llm_bdl = llm_bdl
self.embed_bdl = embed_bdl
self.callback = callback
async def __call__(self):
with RedisDistributedLock(self.kb_id, 60*60):
self.graph, doc_ids = get_graph(self.tenant_id, self.kb_id)
if not self.graph:
logging.error(f"Faild to fetch the graph. tenant_id:{self.kb_id}, kb_id:{self.kb_id}")
if self.callback:
self.callback(-1, msg="Faild to fetch the graph.")
return
if self.callback:
self.callback(msg="Fetch the existing graph.")

cr = CommunityReportsExtractor(self.llm_bdl,
get_entity=partial(get_entity, self.tenant_id, self.kb_id),
set_entity=partial(set_entity, self.tenant_id, self.kb_id, self.embed_bdl),
get_relation=partial(get_relation, self.tenant_id, self.kb_id),
set_relation=partial(set_relation, self.tenant_id, self.kb_id, self.embed_bdl))
cr = await cr(self.graph, callback=self.callback)
self.community_structure = cr.structured_output
self.community_reports = cr.output
await trio.to_thread.run_sync(lambda: set_graph(self.tenant_id, self.kb_id, self.graph, doc_ids))

if self.callback:
self.callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))

await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"from_entity_kwd": reso.removed_entities,
},
search.index_name(tenant_id),
kb_id,
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "relation",
"kb_id": kb_id,
"to_entity_kwd": reso.removed_entities,
},
search.index_name(tenant_id),
kb_id,
)
)
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{
"knowledge_graph_kwd": "entity",
"kb_id": kb_id,
"entity_kwd": reso.removed_entities,
},
search.index_name(tenant_id),
kb_id,
)
)
now = trio.current_time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")


async def extract_community(
graph,
doc_ids,
tenant_id: str,
kb_id: str,
doc_id: str,
llm_bdl,
embed_bdl,
callback,
):
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
start = trio.current_time()
ext = CommunityReportsExtractor(
llm_bdl,
get_entity=partial(get_entity, tenant_id, kb_id),
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
get_relation=partial(get_relation, tenant_id, kb_id),
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
)
cr = await ext(graph, callback=callback)
community_structure = cr.structured_output
community_reports = cr.output
working_doc_id = graphrag_task_get(tenant_id, kb_id)
if doc_id != working_doc_id:
callback(
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
)
return
await set_graph(tenant_id, kb_id, graph, doc_ids)

now = trio.current_time()
callback(
msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
)
start = now
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.delete(
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
search.index_name(tenant_id),
kb_id,
)
)
for stru, rep in zip(community_structure, community_reports):
obj = {
"report": rep,
"evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
}
chunk = {
"docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]),
"content_with_weight": json.dumps(obj, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(
obj["report"] + " " + obj["evidences"]
),
"knowledge_graph_kwd": "community_report", "knowledge_graph_kwd": "community_report",
"kb_id": self.kb_id
}, search.index_name(self.tenant_id), self.kb_id))

for stru, rep in zip(self.community_structure, self.community_reports):
obj = {
"report": rep,
"evidences": "\n".join([f["explanation"] for f in stru["findings"]])
}
chunk = {
"docnm_kwd": stru["title"],
"title_tks": rag_tokenizer.tokenize(stru["title"]),
"content_with_weight": json.dumps(obj, ensure_ascii=False),
"content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
"knowledge_graph_kwd": "community_report",
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": self.kb_id,
"source_id": doc_ids,
"available_int": 0
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
#try:
# ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
#except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(self.tenant_id)))
"weight_flt": stru["weight"],
"entities_kwd": stru["entities"],
"important_kwd": stru["entities"],
"kb_id": kb_id,
"source_id": doc_ids,
"available_int": 0,
}
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
chunk["content_ltks"]
)
# try:
# ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
# except Exception as e:
# logging.exception(f"Fail to embed entity relation: {e}")
await trio.to_thread.run_sync(
lambda: settings.docStoreConn.insert(
[{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
)
)


now = trio.current_time()
callback(
msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
)
return community_structure, community_reports

+ 65
- 22
graphrag/general/smoke.py View File



import argparse import argparse
import json import json
import logging
import networkx as nx import networkx as nx
import trio import trio


from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from graphrag.general.index import WithCommunity, Dealer, WithResolution
from graphrag.light.graph_extractor import GraphExtractor
from rag.utils.redis_conn import RedisDistributedLock
from graphrag.general.graph_extractor import GraphExtractor
from graphrag.general.index import update_graph, with_resolution, with_community


settings.init_settings() settings.init_settings()


if __name__ == "__main__":

def callback(prog=None, msg="Processing..."):
logging.info(msg)


async def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
parser.add_argument(
"-t",
"--tenant_id",
default=False,
help="Tenant ID",
action="store",
required=True,
)
parser.add_argument(
"-d",
"--doc_id",
default=False,
help="Document ID",
action="store",
required=True,
)
args = parser.parse_args() args = parser.parse_args()
e, doc = DocumentService.get_by_id(args.doc_id) e, doc = DocumentService.get_by_id(args.doc_id)
if not e: if not e:
raise LookupError("Document not found.") raise LookupError("Document not found.")
kb_id = doc.kb_id kb_id = doc.kb_id


chunks = [d["content_with_weight"] for d in
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
fields=["content_with_weight"])]
chunks = [("x", c) for c in chunks]

RedisDistributedLock.clean_lock(kb_id)
chunks = [
d["content_with_weight"]
for d in settings.retrievaler.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],
max_count=6,
fields=["content_with_weight"],
)
]


_, tenant = TenantService.get_by_id(args.tenant_id) _, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id) llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id) _, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)


dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
trio.run(dealer())
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
graph, doc_ids = await update_graph(
GraphExtractor,
args.tenant_id,
kb_id,
args.doc_id,
chunks,
"English",
llm_bdl,
embed_bdl,
callback,
)
print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))

await with_resolution(
args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
)
community_structure, community_reports = await with_community(
args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
)


dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer())
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
trio.run(dealer())
print(
"------------------ COMMUNITY STRUCTURE--------------------\n",
json.dumps(community_structure, ensure_ascii=False, indent=2),
)
print(
"------------------ COMMUNITY REPORTS----------------------\n",
community_reports,
)


print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))

if __name__ == "__main__":
trio.run(main)

+ 51
- 13
graphrag/light/smoke.py View File

import json import json
from api import settings from api import settings
import networkx as nx import networkx as nx
import logging
import trio


from api.db import LLMType from api.db import LLMType
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from graphrag.general.index import Dealer
from graphrag.general.index import update_graph
from graphrag.light.graph_extractor import GraphExtractor from graphrag.light.graph_extractor import GraphExtractor
from rag.utils.redis_conn import RedisDistributedLock


settings.init_settings() settings.init_settings()


if __name__ == "__main__":

def callback(prog=None, msg="Processing..."):
logging.info(msg)


async def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
parser.add_argument(
"-t",
"--tenant_id",
default=False,
help="Tenant ID",
action="store",
required=True,
)
parser.add_argument(
"-d",
"--doc_id",
default=False,
help="Document ID",
action="store",
required=True,
)
args = parser.parse_args() args = parser.parse_args()


e, doc = DocumentService.get_by_id(args.doc_id) e, doc = DocumentService.get_by_id(args.doc_id)
raise LookupError("Document not found.") raise LookupError("Document not found.")
kb_id = doc.kb_id kb_id = doc.kb_id


chunks = [d["content_with_weight"] for d in
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
fields=["content_with_weight"])]
chunks = [("x", c) for c in chunks]

RedisDistributedLock.clean_lock(kb_id)
chunks = [
d["content_with_weight"]
for d in settings.retrievaler.chunk_list(
args.doc_id,
args.tenant_id,
[kb_id],
max_count=6,
fields=["content_with_weight"],
)
]


_, tenant = TenantService.get_by_id(args.tenant_id) _, tenant = TenantService.get_by_id(args.tenant_id)
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id) llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
_, kb = KnowledgebaseService.get_by_id(kb_id) _, kb = KnowledgebaseService.get_by_id(kb_id)
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)


dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
graph, doc_ids = await update_graph(
GraphExtractor,
args.tenant_id,
kb_id,
args.doc_id,
chunks,
"English",
llm_bdl,
embed_bdl,
callback,
)

print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))


print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))

if __name__ == "__main__":
trio.run(main)

+ 56
- 24
graphrag/utils.py View File

chunk["q_%d_vec" % len(ebd)] = ebd chunk["q_%d_vec" % len(ebd)] = ebd
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id) settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)


async def does_graph_contains(tenant_id, kb_id, doc_id):
# Get doc_ids of graph
fields = ["source_id"]
condition = {
"knowledge_graph_kwd": ["graph"],
"removed_kwd": "N",
}
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
fields2 = settings.docStoreConn.getFields(res, fields)
graph_doc_ids = set()
for chunk_id in fields2.keys():
graph_doc_ids = set(fields2[chunk_id]["source_id"])
return doc_id in graph_doc_ids

async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
conds = {
"fields": ["source_id"],
"removed_kwd": "N",
"size": 1,
"knowledge_graph_kwd": ["graph"]
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
doc_ids = []
if res.total == 0:
return doc_ids
for id in res.ids:
doc_ids = res.field[id]["source_id"]
return doc_ids



def get_graph(tenant_id, kb_id):
async def get_graph(tenant_id, kb_id):
conds = { conds = {
"fields": ["content_with_weight", "source_id"], "fields": ["content_with_weight", "source_id"],
"removed_kwd": "N", "removed_kwd": "N",
"size": 1, "size": 1,
"knowledge_graph_kwd": ["graph"] "knowledge_graph_kwd": ["graph"]
} }
res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
if res.total == 0:
return None, []
for id in res.ids: for id in res.ids:
try: try:
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \ return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
res.field[id]["source_id"] res.field[id]["source_id"]
except Exception: except Exception:
continue continue
return rebuild_graph(tenant_id, kb_id)
result = await rebuild_graph(tenant_id, kb_id)
return result




def set_graph(tenant_id, kb_id, graph, docids):
async def set_graph(tenant_id, kb_id, graph, docids):
chunk = { chunk = {
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False, "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
indent=2), indent=2),
"source_id": list(docids), "source_id": list(docids),
"available_int": 0, "available_int": 0,
"removed_kwd": "N" "removed_kwd": "N"
}
res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
}
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
if res.ids: if res.ids:
settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
search.index_name(tenant_id), kb_id))
else: else:
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))




def is_continuous_subsequence(subseq, seq): def is_continuous_subsequence(subseq, seq):
return result return result




def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
def n_neighbor(id): def n_neighbor(id):
nonlocal graph, n_hop nonlocal graph, n_hop
count = 0 count = 0
for n, p in pr.items(): for n, p in pr.items():
graph.nodes[n]["pagerank"] = p graph.nodes[n]["pagerank"] = p
try: try:
settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
{"rank_flt": p, {"rank_flt": p,
"n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
search.index_name(tenant_id), kb_id)
"n_hop_with_weight": json.dumps( (n), ensure_ascii=False)},
search.index_name(tenant_id), kb_id))
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)


"knowledge_graph_kwd": "ty2ents", "knowledge_graph_kwd": "ty2ents",
"available_int": 0 "available_int": 0
} }
res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id])
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
search.index_name(tenant_id), [kb_id]))
if res.ids: if res.ids:
settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
chunk, chunk,
search.index_name(tenant_id), kb_id)
search.index_name(tenant_id), kb_id))
else: else:
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))




def get_entity_type2sampels(idxnms, kb_ids: list):
es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
async def get_entity_type2sampels(idxnms, kb_ids: list):
es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
"size": 10000, "size": 10000,
"fields": ["content_with_weight"]}, "fields": ["content_with_weight"]},
idxnms, kb_ids)
idxnms, kb_ids))


res = defaultdict(list) res = defaultdict(list)
for id in es_res.ids: for id in es_res.ids:
return list(set(res)) return list(set(res))




def rebuild_graph(tenant_id, kb_id):
async def rebuild_graph(tenant_id, kb_id):
graph = nx.Graph() graph = nx.Graph()
src_ids = [] src_ids = []
flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"] flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
bs = 256 bs = 256
for i in range(0, 39*bs, bs): for i in range(0, 39*bs, bs):
es_res = settings.docStoreConn.search(flds, [],
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]}, {"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
[], [],
OrderByExpr(), OrderByExpr(),
i, bs, search.index_name(tenant_id), [kb_id] i, bs, search.index_name(tenant_id), [kb_id]
)
))
tot = settings.docStoreConn.getTotal(es_res) tot = settings.docStoreConn.getTotal(es_res)
if tot == 0: if tot == 0:
return None, None return None, None

+ 65
- 38
rag/raptor.py View File

# #
import logging import logging
import re import re
from threading import Lock
import umap import umap
import numpy as np import numpy as np
from sklearn.mixture import GaussianMixture from sklearn.mixture import GaussianMixture
import trio import trio


from graphrag.utils import get_llm_cache, get_embed_cache, set_embed_cache, set_llm_cache, chat_limiter
from graphrag.utils import (
get_llm_cache,
get_embed_cache,
set_embed_cache,
set_llm_cache,
chat_limiter,
)
from rag.utils import truncate from rag.utils import truncate




class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def __init__(self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1):
def __init__(
self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
):
self._max_cluster = max_cluster self._max_cluster = max_cluster
self._llm_model = llm_model self._llm_model = llm_model
self._embd_model = embd_model self._embd_model = embd_model
self._prompt = prompt self._prompt = prompt
self._max_token = max_token self._max_token = max_token


def _chat(self, system, history, gen_conf):
async def _chat(self, system, history, gen_conf):
response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf) response = get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
if response: if response:
return response return response
response = self._llm_model.chat(system, history, gen_conf)
response = await trio.to_thread.run_sync(
lambda: self._llm_model.chat(system, history, gen_conf)
)
response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL) response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0: if response.find("**ERROR**") >= 0:
raise Exception(response) raise Exception(response)
set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf) set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
return response return response


def _embedding_encode(self, txt):
async def _embedding_encode(self, txt):
response = get_embed_cache(self._embd_model.llm_name, txt) response = get_embed_cache(self._embd_model.llm_name, txt)
if response is not None: if response is not None:
return response return response
embds, _ = self._embd_model.encode([txt])
embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
if len(embds) < 1 or len(embds[0]) < 1: if len(embds) < 1 or len(embds[0]) < 1:
raise Exception("Embedding error: ") raise Exception("Embedding error: ")
embds = embds[0] embds = embds[0]
return [] return []
chunks = [(s, a) for s, a in chunks if s and len(a) > 0] chunks = [(s, a) for s, a in chunks if s and len(a) > 0]


async def summarize(ck_idx, lock):
async def summarize(ck_idx: list[int]):
nonlocal chunks nonlocal chunks
try:
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
async with chat_limiter:
cnt = await trio.to_thread.run_sync(lambda: self._chat("You're a helpful assistant.",
[{"role": "user",
"content": self._prompt.format(cluster_content=cluster_content)}],
{"temperature": 0.3, "max_tokens": self._max_token}
))
cnt = re.sub("(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", "",
cnt)
logging.debug(f"SUM: {cnt}")
embds, _ = self._embd_model.encode([cnt])
with lock:
chunks.append((cnt, self._embedding_encode(cnt)))
except Exception as e:
logging.exception("summarize got exception")
return e
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int(
(self._llm_model.max_length - self._max_token) / len(texts)
)
cluster_content = "\n".join(
[truncate(t, max(1, len_per_chunk)) for t in texts]
)
async with chat_limiter:
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(
cluster_content=cluster_content
),
}
],
{"temperature": 0.3, "max_tokens": self._max_token},
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))


labels = [] labels = []
lock = Lock()
while end - start > 1: while end - start > 1:
embeddings = [embd for _, embd in chunks[start: end]]
embeddings = [embd for _, embd in chunks[start:end]]
if len(embeddings) == 2: if len(embeddings) == 2:
await summarize([start, start + 1], lock)
await summarize([start, start + 1])
if callback: if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
labels.extend([0, 0]) labels.extend([0, 0])
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
start = end start = end


n_neighbors = int((len(embeddings) - 1) ** 0.8) n_neighbors = int((len(embeddings) - 1) ** 0.8)
reduced_embeddings = umap.UMAP( reduced_embeddings = umap.UMAP(
n_neighbors=max(2, n_neighbors), n_components=min(12, len(embeddings) - 2), metric="cosine"
n_neighbors=max(2, n_neighbors),
n_components=min(12, len(embeddings) - 2),
metric="cosine",
).fit_transform(embeddings) ).fit_transform(embeddings)
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state) n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state)
if n_clusters == 1: if n_clusters == 1:
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for c in range(n_clusters): for c in range(n_clusters):
ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c] ck_idx = [i + start for i in range(len(lbls)) if lbls[i] == c]
if not ck_idx:
continue
assert len(ck_idx) > 0
async with chat_limiter: async with chat_limiter:
nursery.start_soon(lambda: summarize(ck_idx, lock))
nursery.start_soon(lambda: summarize(ck_idx))


assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
assert len(chunks) - end == n_clusters, "{} vs. {}".format(
len(chunks) - end, n_clusters
)
labels.extend(lbls) labels.extend(lbls)
layers.append((end, len(chunks))) layers.append((end, len(chunks)))
if callback: if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
callback(
msg="Cluster one layer: {} -> {}".format(
end - start, len(chunks) - end
)
)
start = end start = end
end = len(chunks) end = len(chunks)


return chunks return chunks


+ 11
- 40
rag/svr/task_executor.py View File

import sys import sys


from api.utils.log_utils import initRootLogger, get_project_base_directory from api.utils.log_utils import initRootLogger, get_project_base_directory
from graphrag.general.index import WithCommunity, WithResolution, Dealer
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
from graphrag.general.index import run_graphrag
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
from rag.prompts import keyword_extraction, question_proposal, content_tagging from rag.prompts import keyword_extraction, question_proposal, content_tagging


import resource import resource
import signal import signal
import trio import trio
import exceptiongroup


import numpy as np import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
return res, tk_count return res, tk_count




async def run_graphrag(row, chat_model, language, embedding_model, callback=None):
chunks = []
for d in settings.retrievaler.chunk_list(row["doc_id"], row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", "doc_id"]):
chunks.append((d["doc_id"], d["content_with_weight"]))

dealer = Dealer(LightKGExt if row["parser_config"]["graphrag"]["method"] != 'general' else GeneralKGExt,
row["tenant_id"],
str(row["kb_id"]),
chat_model,
chunks=chunks,
language=language,
entity_types=row["parser_config"]["graphrag"]["entity_types"],
embed_bdl=embedding_model,
callback=callback)
await dealer()


async def do_handle_task(task): async def do_handle_task(task):
task_id = task["id"] task_id = task["id"]
task_from_page = task["from_page"] task_from_page = task["from_page"]
return return
start_ts = timer() start_ts = timer()
chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language) chat_model = LLMBundle(task_tenant_id, LLMType.CHAT, llm_name=task_llm_id, lang=task_language)
await run_graphrag(task, chat_model, task_language, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph basic is done ({:.2f}s)".format(timer() - start_ts))
if graphrag_conf.get("resolution", False):
start_ts = timer()
with_res = WithResolution(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_res()
progress_callback(prog=1.0, msg="Knowledge Graph resolution is done ({:.2f}s)".format(timer() - start_ts))
if graphrag_conf.get("community", False):
start_ts = timer()
with_comm = WithCommunity(
task["tenant_id"], str(task["kb_id"]), chat_model, embedding_model,
progress_callback
)
await with_comm()
progress_callback(prog=1.0, msg="Knowledge Graph community is done ({:.2f}s)".format(timer() - start_ts))
with_resolution = graphrag_conf.get("resolution", False)
with_community = graphrag_conf.get("community", False)
await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback)
progress_callback(prog=1.0, msg="Knowledge Graph done ({:.2f}s)".format(timer() - start_ts))
return return
else: else:
# Standard chunking methods # Standard chunking methods
FAILED_TASKS += 1 FAILED_TASKS += 1
CURRENT_TASKS.pop(task["id"], None) CURRENT_TASKS.pop(task["id"], None)
try: try:
set_progress(task["id"], prog=-1, msg=f"[Exception]: {e}")
err_msg = str(e)
while isinstance(e, exceptiongroup.ExceptionGroup):
e = e.exceptions[0]
err_msg += ' -- ' + str(e)
set_progress(task["id"], prog=-1, msg=f"[Exception]: {err_msg}")
except Exception: except Exception:
pass pass
logging.exception(f"handle_task got exception for task {json.dumps(task)}") logging.exception(f"handle_task got exception for task {json.dumps(task)}")

+ 13
- 20
rag/utils/redis_conn.py View File



import logging import logging
import json import json
import time
import uuid import uuid


import valkey as redis import valkey as redis
from rag import settings from rag import settings
from rag.utils import singleton from rag.utils import singleton
from valkey.lock import Lock


class RedisMsg: class RedisMsg:
def __init__(self, consumer, queue_name, group_name, msg_id, message): def __init__(self, consumer, queue_name, group_name, msg_id, message):




class RedisDistributedLock: class RedisDistributedLock:
def __init__(self, lock_key, timeout=10):
def __init__(self, lock_key, lock_value=None, timeout=10, blocking_timeout=1):
self.lock_key = lock_key self.lock_key = lock_key
self.lock_value = str(uuid.uuid4())
if lock_value:
self.lock_value = lock_value
else:
self.lock_value = str(uuid.uuid4())
self.timeout = timeout self.timeout = timeout
self.lock = Lock(REDIS_CONN.REDIS, lock_key, timeout=timeout, blocking_timeout=blocking_timeout)


@staticmethod
def clean_lock(lock_key):
REDIS_CONN.REDIS.delete(lock_key)

def acquire_lock(self):
end_time = time.time() + self.timeout
while time.time() < end_time:
if REDIS_CONN.REDIS.setnx(self.lock_key, self.lock_value):
return True
time.sleep(1)
return False
def acquire(self):
return self.lock.acquire()


def release_lock(self):
if REDIS_CONN.REDIS.get(self.lock_key) == self.lock_value:
REDIS_CONN.REDIS.delete(self.lock_key)
def release(self):
return self.lock.release()


def __enter__(self): def __enter__(self):
self.acquire_lock()
self.acquire()


def __exit__(self, exception_type, exception_value, exception_traceback): def __exit__(self, exception_type, exception_value, exception_traceback):
self.release_lock()
self.release()

Loading…
Cancel
Save