浏览代码

Fix bugs (#3502)

### What problem does this PR solve?

1. Remove unused code
2. Fix type mismatches in the NLP search and Infinity search interfaces
3. Fix chunk listing so that it gets all chunks of this user

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
tags/v0.14.0
Jin Hai 11 个月前
父节点
当前提交
2044bb0039
共有 3 个文件被更改，包括 54 次插入和 46 次删除
  1. 1
    2
      agent/component/base.py
  2. 1
    1
      graphrag/search.py
  3. 52
    43
      rag/utils/infinity_conn.py

+ 1
- 2
agent/component/base.py 查看文件

import builtins import builtins
import json import json
import os import os
import logging
from functools import partial from functools import partial
from typing import Tuple, Union from typing import Tuple, Union


import pandas as pd import pandas as pd


from agent import settings from agent import settings
from agent.settings import flow_logger, DEBUG


_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params" _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params" _DEPRECATED_PARAMS = "_deprecated_params"


upstream_outs = [] upstream_outs = []


if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]: for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval": if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":

+ 1
- 1
graphrag/search.py 查看文件





class KGSearch(Dealer): class KGSearch(Dealer):
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
def merge_into_first(sres, title="") -> dict[str, str]: def merge_into_first(sres, title="") -> dict[str, str]:
if not sres: if not sres:
return {} return {}

+ 52
- 43
rag/utils/infinity_conn.py 查看文件

import json import json
import time import time
import infinity import infinity
from infinity.common import ConflictType, InfinityException
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool from infinity.connection_pool import ConnectionPool
from rag import settings from rag import settings
OrderByExpr, OrderByExpr,
) )



def equivalent_condition_to_str(condition: dict) -> str: def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition assert "_id" not in condition
cond = list() cond = list()
self.connPool = connPool self.connPool = connPool
break break
except Exception as e: except Exception as e:
logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5) time.sleep(5)
if self.connPool is None: if self.connPool is None:
msg = f"Infinity {infinity_uri} didn't become healthy in 120s." msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
self.connPool.release_conn(inf_conn) self.connPool.release_conn(inf_conn)
return True return True
except Exception as e: except Exception as e:
logging.warn(f"INFINITY indexExist {str(e)}")
logging.warning(f"INFINITY indexExist {str(e)}")
return False return False


""" """
""" """


def search( def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame: ) -> list[dict] | pl.DataFrame:
""" """
TODO: Infinity doesn't provide highlight TODO: Infinity doesn't provide highlight
minimum_should_match = "0%" minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options: if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = ( minimum_should_match = (
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
) )
matchExpr.extra_options.update( matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match} {"minimum_should_match": minimum_should_match}
for k, v in matchExpr.extra_options.items(): for k, v in matchExpr.extra_options.items():
if not isinstance(v, str): if not isinstance(v, str):
matchExpr.extra_options[k] = str(v) matchExpr.extra_options[k] = str(v)

order_by_expr_list = list()
if orderBy.fields: if orderBy.fields:
order_by_expr_list = list()
for order_field in orderBy.fields: for order_field in orderBy.fields:
order_by_expr_list.append((order_field[0], order_field[1] == 0))
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
order_by_expr_list.append((order_field[0], SortType.Desc))


# Scatter search tables and gather the results # Scatter search tables and gather the results
for indexName in indexNames: for indexName in indexNames:
continue continue
table_list.append(table_name) table_list.append(table_name)
builder = table_instance.output(selectFields) builder = table_instance.output(selectFields)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
else:
if len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields: if orderBy.fields:
builder.sort(order_by_expr_list) builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit) builder.offset(offset).limit(limit)
return res return res


def get( def get(
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None: ) -> dict | None:
inf_conn = self.connPool.get_conn() inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
return res_fields.get(chunkId, None) return res_fields.get(chunkId, None)


def insert( def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str
self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]: ) -> list[str]:
inf_conn = self.connPool.get_conn() inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName) db_instance = inf_conn.get_database(self.dbName)
return [] return []


def update( def update(
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool: ) -> bool:
# if 'position_list' in newValue: # if 'position_list' in newValue:
# logging.info(f"upsert position_list: {newValue['position_list']}") # logging.info(f"upsert position_list: {newValue['position_list']}")
flags=re.IGNORECASE | re.MULTILINE, flags=re.IGNORECASE | re.MULTILINE,
) )
if not re.search( if not re.search(
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
): ):
continue continue
txts.append(t) txts.append(t)

正在加载...
取消
保存