### What problem does this PR solve?

Refactoring of the GraphRAG task pipeline: `Extractor._chat` now retries failed LLM calls up to three times instead of returning an empty string, `run_graphrag` pressure-tests the chat and embedding models before starting, `entity_types` is read defensively from the kb parser config, and `set_graph` reports progress while embedding nodes/edges and inserting chunks.

### Type of change

- [x] Refactoring
```diff
@@ -55,12 +55,18 @@ class Extractor:
         if response:
             return response
         _, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
-        response = self._llm.chat(system_msg[0]["content"], hist, conf)
-        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
-        if response.find("**ERROR**") >= 0:
-            logging.warning(f"Extractor._chat got error. response: {response}")
-            return ""
-        set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
+        for attempt in range(3):
+            try:
+                response = self._llm.chat(system_msg[0]["content"], hist, conf)
+                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
+                if response.find("**ERROR**") >= 0:
+                    raise Exception(response)
+                set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
+                break
+            except Exception as e:
+                logging.exception(e)
+                if attempt == 2:
+                    raise
         return response

     def _entities_and_relations(self, chunk_key: str, records: list, tuple_delimiter: str):
```
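The retry loop above turns in-band `**ERROR**` responses into exceptions so that transient LLM failures get retried instead of silently producing an empty extraction. A minimal standalone sketch of the same pattern (`call_llm` and `MAX_ATTEMPTS` are illustrative names, not part of this PR):

```python
import logging

MAX_ATTEMPTS = 3  # mirrors range(3) in the hunk above

def chat_with_retry(call_llm, *args, **kwargs):
    """Retry a flaky call, re-raising only after the final attempt fails."""
    for attempt in range(MAX_ATTEMPTS):
        try:
            response = call_llm(*args, **kwargs)
            if "**ERROR**" in response:
                # The model wrapper reports failures in-band; promote them
                # to exceptions so they go through the retry path.
                raise Exception(response)
            return response
        except Exception as e:
            logging.exception(e)
            if attempt == MAX_ATTEMPTS - 1:
                raise  # out of attempts, propagate to the caller
```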
```diff
@@ -39,6 +39,14 @@ from rag.nlp import rag_tokenizer, search
 from rag.utils.redis_conn import RedisDistributedLock


+@timeout(30, 2)
+async def _is_strong_enough(chat_model, embedding_model):
+    _ = await trio.to_thread.run_sync(lambda: embedding_model.encode(["Are you strong enough!?"]))
+    res = await trio.to_thread.run_sync(lambda: chat_model.chat("Nothing special.", [{"role": "user", "content": "Are you strong enough!?"}]))
+    if res.find("**ERROR**") >= 0:
+        raise Exception(res)
+
+
 async def run_graphrag(
     row: dict,
     language,
```
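`_is_strong_enough` probes both models with a trivial request and raises if the chat model answers with an in-band `**ERROR**`. The `@timeout(30, 2)` decorator is a project-internal helper whose implementation is not part of this diff; assuming its arguments are a per-call deadline in seconds and a retry count (an assumption, not confirmed by the diff), a trio-based sketch could look like:

```python
import functools
import trio

def timeout(seconds, attempts=1):
    # Hypothetical reconstruction: cancel the wrapped coroutine after
    # `seconds`, retrying up to `attempts` times before giving up.
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            last_exc = None
            for _ in range(attempts):
                try:
                    with trio.fail_after(seconds):
                        return await fn(*args, **kwargs)
                except Exception as e:  # includes trio.TooSlowError
                    last_exc = e
            raise last_exc
        return wrapper
    return decorator
```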
```diff
@@ -48,6 +56,11 @@ async def run_graphrag(
     embedding_model,
     callback,
 ):
+    # Pressure test for GraphRAG task
+    async with trio.open_nursery() as nursery:
+        for _ in range(12):
+            nursery.start_soon(_is_strong_enough, chat_model, embedding_model)
+
     start = trio.current_time()
     tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
     chunks = []
```
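Spawning twelve concurrent `_is_strong_enough` probes makes `run_graphrag` fail fast when either model is unhealthy, before any expensive extraction work starts: a trio nursery only exits once every child task has finished, and the first exception cancels the remaining ones. The pattern in isolation (`check` is a stand-in for the real probe):

```python
import trio

async def check(i):
    # Stand-in for _is_strong_enough: return on success, raise on failure.
    await trio.sleep(0.01)

async def main():
    # All 12 probes run concurrently; an exception in any of them
    # propagates out of the nursery and aborts the pipeline early.
    async with trio.open_nursery() as nursery:
        for i in range(12):
            nursery.start_soon(check, i)
    print("models look healthy, proceeding")

trio.run(main)
```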
```diff
@@ -65,7 +78,7 @@ async def run_graphrag(
             doc_id,
             chunks,
             language,
-            row["kb_parser_config"]["graphrag"]["entity_types"],
+            row["kb_parser_config"]["graphrag"].get("entity_types", []),
             chat_model,
             embedding_model,
             callback,
```
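Switching from `["entity_types"]` to `.get("entity_types", [])` keeps `run_graphrag` from raising `KeyError` on knowledge bases whose parser config omits that field. A tiny illustration with a hypothetical config:

```python
# Hypothetical older config without "entity_types":
row = {"kb_parser_config": {"graphrag": {}}}

graphrag_conf = row["kb_parser_config"]["graphrag"]
print(graphrag_conf.get("entity_types", []))  # [] instead of KeyError
```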
```diff
@@ -484,16 +484,20 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     semaphore = trio.Semaphore(5)
     async with trio.open_nursery() as nursery:
-        for node in change.added_updated_nodes:
+        for ii, node in enumerate(change.added_updated_nodes):
             node_attrs = graph.nodes[node]
             async with semaphore:
+                if ii % 100 == 9 and callback:
+                    callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
                 nursery.start_soon(graph_node_to_chunk, kb_id, embd_mdl, node, node_attrs, chunks)
-        for from_node, to_node in change.added_updated_edges:
+        for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
             edge_attrs = graph.get_edge_data(from_node, to_node)
             if not edge_attrs:
                 # added_updated_edges could record a non-existing edge if both from_node and to_node participate in node merging.
                 continue
             async with semaphore:
+                if ii % 100 == 9 and callback:
+                    callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
                 nursery.start_soon(graph_edge_to_chunk, kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)

     now = trio.current_time()
     if callback:
```
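Enumerating the nodes/edges lets `set_graph` emit a progress callback roughly once per hundred items (whenever `ii % 100 == 9`) instead of staying silent during long embedding runs. The spawning pattern in isolation (names are illustrative):

```python
import trio

async def embed_all(items, embed_one, callback=None):
    # Spawn one embedding task per item, reporting progress about
    # once per 100 items as the hunk above does.
    async with trio.open_nursery() as nursery:
        for ii, item in enumerate(items):
            if ii % 100 == 9 and callback:
                callback(msg=f"Get embedding of nodes: {ii}/{len(items)}")
            nursery.start_soon(embed_one, item)
```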
```diff
@@ -502,6 +506,9 @@ async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, chang
     es_bulk_size = 4
     for b in range(0, len(chunks), es_bulk_size):
         async with semaphore:
+            if b % 100 == es_bulk_size and callback:
+                callback(msg=f"Insert chunks: {b}/{len(chunks)}")
             doc_store_result = await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(tenant_id), kb_id))
             if doc_store_result:
                 error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!"
```
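The same throttling idea applies to the document-store inserts: `b` advances in steps of `es_bulk_size` (4), so `b % 100 == es_bulk_size` fires at b = 4, 104, 204, ..., giving a message about once per hundred chunks. A batching sketch with dummy data:

```python
es_bulk_size = 4
chunks = list(range(250))  # dummy chunks

for b in range(0, len(chunks), es_bulk_size):
    batch = chunks[b:b + es_bulk_size]
    if b % 100 == es_bulk_size:
        print(f"Insert chunks: {b}/{len(chunks)}")
    # docStoreConn.insert(batch, ...) would run here
```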