Procházet zdrojové kódy

Fix bugs (#3502)

### What problem does this PR solve?

1. Remove unused code
2. Fix type mismatch, in nlp search and infinity search interface
3. Fix chunk list, get all chunks of this user.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>
tags/v0.14.0
Jin Hai před 11 měsíci
rodič
revize
2044bb0039
3 změnil soubory, kde provedl 54 přidání a 46 odebrání
  1. 1
    2
      agent/component/base.py
  2. 1
    1
      graphrag/search.py
  3. 52
    43
      rag/utils/infinity_conn.py

+ 1
- 2
agent/component/base.py Zobrazit soubor

@@ -17,13 +17,13 @@ from abc import ABC
import builtins
import json
import os
import logging
from functools import partial
from typing import Tuple, Union

import pandas as pd

from agent import settings
from agent.settings import flow_logger, DEBUG

_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@@ -480,7 +480,6 @@ class ComponentBase(ABC):

upstream_outs = []

if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":

+ 1
- 1
graphrag/search.py Zobrazit soubor

@@ -23,7 +23,7 @@ from rag.nlp.search import Dealer


class KGSearch(Dealer):
def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
def merge_into_first(sres, title="") -> dict[str, str]:
if not sres:
return {}

+ 52
- 43
rag/utils/infinity_conn.py Zobrazit soubor

@@ -4,7 +4,7 @@ import re
import json
import time
import infinity
from infinity.common import ConflictType, InfinityException
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from rag import settings
@@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import (
OrderByExpr,
)


def equivalent_condition_to_str(condition: dict) -> str:
assert "_id" not in condition
cond = list()
@@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool = connPool
break
except Exception as e:
logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
time.sleep(5)
if self.connPool is None:
msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
@@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection):
self.connPool.release_conn(inf_conn)
return True
except Exception as e:
logging.warn(f"INFINITY indexExist {str(e)}")
logging.warning(f"INFINITY indexExist {str(e)}")
return False

"""
@@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection):
"""

def search(
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str|list[str],
knowledgebaseIds: list[str],
self,
selectFields: list[str],
highlightFields: list[str],
condition: dict,
matchExprs: list[MatchExpr],
orderBy: OrderByExpr,
offset: int,
limit: int,
indexNames: str | list[str],
knowledgebaseIds: list[str],
) -> list[dict] | pl.DataFrame:
"""
TODO: Infinity doesn't provide highlight
@@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection):
minimum_should_match = "0%"
if "minimum_should_match" in matchExpr.extra_options:
minimum_should_match = (
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+ "%"
)
matchExpr.extra_options.update(
{"minimum_should_match": minimum_should_match}
@@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection):
for k, v in matchExpr.extra_options.items():
if not isinstance(v, str):
matchExpr.extra_options[k] = str(v)

order_by_expr_list = list()
if orderBy.fields:
order_by_expr_list = list()
for order_field in orderBy.fields:
order_by_expr_list.append((order_field[0], order_field[1] == 0))
if order_field[1] == 0:
order_by_expr_list.append((order_field[0], SortType.Asc))
else:
order_by_expr_list.append((order_field[0], SortType.Desc))

# Scatter search tables and gather the results
for indexName in indexNames:
@@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection):
continue
table_list.append(table_name)
builder = table_instance.output(selectFields)
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
if len(matchExprs) > 0:
for matchExpr in matchExprs:
if isinstance(matchExpr, MatchTextExpr):
fields = ",".join(matchExpr.fields)
builder = builder.match_text(
fields,
matchExpr.matching_text,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, MatchDenseExpr):
builder = builder.match_dense(
matchExpr.vector_column_name,
matchExpr.embedding_data,
matchExpr.embedding_data_type,
matchExpr.distance_type,
matchExpr.topn,
matchExpr.extra_options,
)
elif isinstance(matchExpr, FusionExpr):
builder = builder.fusion(
matchExpr.method, matchExpr.topn, matchExpr.fusion_params
)
else:
if len(filter_cond) > 0:
builder.filter(filter_cond)
if orderBy.fields:
builder.sort(order_by_expr_list)
builder.offset(offset).limit(limit)
@@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection):
return res

def get(
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
) -> dict | None:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection):
return res_fields.get(chunkId, None)

def insert(
self, documents: list[dict], indexName: str, knowledgebaseId: str
self, documents: list[dict], indexName: str, knowledgebaseId: str
) -> list[str]:
inf_conn = self.connPool.get_conn()
db_instance = inf_conn.get_database(self.dbName)
@@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection):
return []

def update(
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
) -> bool:
# if 'position_list' in newValue:
# logging.info(f"upsert position_list: {newValue['position_list']}")
@@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection):
flags=re.IGNORECASE | re.MULTILINE,
)
if not re.search(
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
):
continue
txts.append(t)

Načítá se…
Zrušit
Uložit