Pārlūkot izejas kodu

Perf: limit embedding in KG. (#8917)

### What problem does this PR solve?


### Type of change

- [x] Performance Improvement
tags/v0.20.0
Kevin Hu pirms 3 mēnešiem
vecāks
revīzija
ab53a73768
Revīzijas autora e-pasta adrese nav piesaistīta nevienam kontam
3 mainīti faili ar 39 papildinājumiem un 24 dzēšanām
  1. 5
    3
      api/utils/api_utils.py
  2. 6
    3
      graphrag/entity_resolution.py
  3. 28
    18
      graphrag/utils.py

+ 5
- 3
api/utils/api_utils.py Parādīt failu

@timeout(30, 2) @timeout(30, 2)
async def _is_strong_enough(): async def _is_strong_enough():
nonlocal chat_model, embedding_model nonlocal chat_model, embedding_model
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
with trio.fail_after(3):
_ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
with trio.fail_after(30):
res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role":"user", "content": "Are you strong enough!?"}], {}))
if res.find("**ERROR**") >= 0: if res.find("**ERROR**") >= 0:
raise Exception(res) raise Exception(res)


# Pressure test for GraphRAG task # Pressure test for GraphRAG task
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for _ in range(12):
for _ in range(32):
nursery.start_soon(_is_strong_enough) nursery.start_soon(_is_strong_enough)

+ 6
- 3
graphrag/entity_resolution.py Parādīt failu

return True return True
return False return False


if len(set(a) & set(b)) > 1:
return True
a, b = set(a), set(b)
max_l = max(len(a), len(b))
if max_l < 4:
return len(a & b) > 1

return len(a & b)*1./max_l >= 0.8


return False

+ 28
- 18
graphrag/utils.py Parādīt failu

import xxhash import xxhash
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
import dataclasses import dataclasses

from api.utils.api_utils import timeout from api.utils.api_utils import timeout
from api import settings from api import settings
from api.utils import get_uuid from api.utils import get_uuid
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest() return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()




@timeout(3, 3)
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks): async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks):
global chat_limiter
chunk = { chunk = {
"id": get_uuid(), "id": get_uuid(),
"important_kwd": [ent_name], "important_kwd": [ent_name],
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"]) chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
ebd = get_embed_cache(embd_mdl.llm_name, ent_name) ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
if ebd is None: if ebd is None:
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
async with chat_limiter:
with trio.fail_after(3):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([ent_name]))
ebd = ebd[0] ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, ent_name, ebd) set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
assert ebd is not None assert ebd is not None
return res return res




@timeout(3, 3)
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks): async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
chunk = { chunk = {
"id": get_uuid(), "id": get_uuid(),
txt = f"{from_ent_name}->{to_ent_name}" txt = f"{from_ent_name}->{to_ent_name}"
ebd = get_embed_cache(embd_mdl.llm_name, txt) ebd = get_embed_cache(embd_mdl.llm_name, txt)
if ebd is None: if ebd is None:
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
async with chat_limiter:
with trio.fail_after(3):
ebd, _ = await trio.to_thread.run_sync(lambda: embd_mdl.encode([txt+f": {meta['description']}"]))
ebd = ebd[0] ebd = ebd[0]
set_embed_cache(embd_mdl.llm_name, txt, ebd) set_embed_cache(embd_mdl.llm_name, txt, ebd)
assert ebd is not None assert ebd is not None




async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback): async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
global chat_limiter
start = trio.current_time() start = trio.current_time()


await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id)) await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["graph", "subgraph"]}, search.index_name(tenant_id), kb_id))
if change.removed_nodes: if change.removed_nodes:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id)) await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["entity"], "entity_kwd": sorted(change.removed_nodes)}, search.index_name(tenant_id), kb_id))



if change.removed_edges: if change.removed_edges:
async def del_edges(from_node, to_node):
async with chat_limiter:
await trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id))
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for from_node, to_node in change.removed_edges: for from_node, to_node in change.removed_edges:
nursery.start_soon(lambda from_node=from_node, to_node=to_node: trio.to_thread.run_sync(lambda: settings.docStoreConn.delete({"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node}, search.index_name(tenant_id), kb_id)))
nursery.start_soon(del_edges, from_node, to_node)

now = trio.current_time() now = trio.current_time()
if callback: if callback:
callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.") callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {now - start:.2f}s.")
"removed_kwd": "N" "removed_kwd": "N"
}) })


semaphore = trio.Semaphore(5)
async with trio.open_nursery() as nursery: async with trio.open_nursery() as nursery:
for ii, node in enumerate(change.added_updated_nodes): for ii, node in enumerate(change.added_updated_nodes):
node_attrs = graph.nodes[node] node_attrs = graph.nodes[node]
async with semaphore:
if ii%100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
if ii%100 == 9 and callback:
callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")

async with trio.open_nursery() as nursery:
for ii, (from_node, to_node) in enumerate(change.added_updated_edges): for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
edge_attrs = graph.get_edge_data(from_node, to_node) edge_attrs = graph.get_edge_data(from_node, to_node)
if not edge_attrs: if not edge_attrs:
# added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging. # added_updated_edges could record a non-existing edge if both from_node and to_node participate in nodes merging.
continue continue
async with semaphore:
if ii%100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
if ii%100 == 9 and callback:
callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
now = trio.current_time() now = trio.current_time()
if callback: if callback:
callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.") callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")


es_bulk_size = 4 es_bulk_size = 4
for b in range(0, len(chunks), es_bulk_size): for b in range(0, len(chunks), es_bulk_size):
async with semaphore:
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
with trio.fail_after(3):
doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
if b % 100 == es_bulk_size and callback:
callback(msg=f"Insert chunks: {b}/{len(chunks)}")
if doc_store_result: if doc_store_result:
error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
raise Exception(error_message) raise Exception(error_message)

Notiek ielāde…
Atcelt
Saglabāt