@@ -28,8 +28,7 @@ from infinity.errors import ErrorCode
 from rag import settings
 from rag.settings import PAGERANK_FLD
 from rag.utils import singleton
-import polars as pl
-from polars.series.series import Series
+import pandas as pd
 from api.utils.file_utils import get_project_base_directory
 
 from rag.utils.doc_store_conn import (
@@ -90,20 +89,20 @@ def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
     return " AND ".join(cond) if cond else "1=1"
 
 
-def concat_dataframes(df_list: list[pl.DataFrame], selectFields: list[str]) -> pl.DataFrame:
-    """
-    Concatenate multiple dataframes into one.
-    """
-    df_list = [df for df in df_list if not df.is_empty()]
-    if df_list:
-        return pl.concat(df_list)
-    schema = dict()
+def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
+    df_list2 = [df for df in df_list if not df.empty]
+    if df_list2:
+        return pd.concat(df_list2, axis=0).reset_index(drop=True)
+
+    schema = []
     for field_name in selectFields:
-        if field_name == 'score()': # Workaround: fix schema is changed to score()
-            schema['SCORE'] = str
+        if field_name == 'score()': # Workaround: fix schema is changed to score()
+            schema.append('SCORE')
+        elif field_name == 'similarity()': # Workaround: fix schema is changed to similarity()
+            schema.append('SIMILARITY')
         else:
-            schema[field_name] = str
-    return pl.DataFrame(schema=schema)
+            schema.append(field_name)
+    return pd.DataFrame(columns=schema)
 
 
 @singleton
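Reviewer note (not part of the patch): a quick sanity check of the new pandas fallback path, using concat_dataframes exactly as defined above; the import path is an assumption. With all-empty inputs, the pseudo-fields score() and similarity() map to the SCORE and SIMILARITY columns that the later sorting code addresses.

import pandas as pd

from rag.utils.infinity_conn import concat_dataframes  # module path assumed

# Non-empty frames are concatenated and re-indexed.
df = concat_dataframes([pd.DataFrame({"id": ["a"]}), pd.DataFrame({"id": ["b"]})], ["id"])
print(df["id"].tolist())       # ['a', 'b']

# All-empty input falls through to the schema branch.
empty = concat_dataframes([pd.DataFrame()], ["id", "score()"])
print(empty.columns.tolist())  # ['id', 'SCORE']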
@@ -121,7 +120,7 @@ class InfinityConnection(DocStoreConnection):
                 connPool = ConnectionPool(infinity_uri)
                 inf_conn = connPool.get_conn()
                 res = inf_conn.show_current_node()
-                if res.error_code == ErrorCode.OK and res.server_status == "started":
+                if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
                     self._migrate_db(inf_conn)
                     self.connPool = connPool
                     connPool.release_conn(inf_conn)
@@ -189,7 +188,7 @@ class InfinityConnection(DocStoreConnection):
         self.connPool.release_conn(inf_conn)
         res2 = {
             "type": "infinity",
-            "status": "green" if res.error_code == 0 and res.server_status == "started" else "red",
+            "status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
             "error": res.error_msg,
         }
         return res2
@@ -281,7 +280,7 @@ class InfinityConnection(DocStoreConnection):
             knowledgebaseIds: list[str],
             aggFields: list[str] = [],
             rank_feature: dict | None = None
-    ) -> list[dict] | pl.DataFrame:
+    ) -> tuple[pd.DataFrame, int]:
         """
         TODO: Infinity doesn't provide highlight
         """
@@ -292,9 +291,10 @@ class InfinityConnection(DocStoreConnection):
         db_instance = inf_conn.get_database(self.dbName)
         df_list = list()
         table_list = list()
+        output = selectFields.copy()
         for essential_field in ["id"]:
-            if essential_field not in selectFields:
-                selectFields.append(essential_field)
+            if essential_field not in output:
+                output.append(essential_field)
         score_func = ""
         score_column = ""
         for matchExpr in matchExprs:
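Reviewer note (not part of the patch): output = selectFields.copy() exists because search() previously appended engine-internal fields (id, the score function, PAGERANK_FLD) directly to the caller's list. A minimal sketch of the aliasing being avoided, with toy stand-ins rather than the real method:

def search_old(selectFields):
    selectFields.append("score()")  # mutates the caller's list in place
    return selectFields

def search_new(selectFields):
    output = selectFields.copy()    # private working copy
    output.append("score()")
    return output

fields = ["id", "content_with_weight"]
search_old(fields)
print(fields)  # ['id', 'content_with_weight', 'score()'] -- leaked
fields = ["id", "content_with_weight"]
search_new(fields)
print(fields)  # ['id', 'content_with_weight'] -- unchanged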
@@ -309,9 +309,11 @@ class InfinityConnection(DocStoreConnection):
                 score_column = "SIMILARITY"
                 break
         if matchExprs:
-            selectFields.append(score_func)
-            selectFields.append(PAGERANK_FLD)
-            selectFields = [f for f in selectFields if f != "_score"]
+            if score_func not in output:
+                output.append(score_func)
+            if PAGERANK_FLD not in output:
+                output.append(PAGERANK_FLD)
+            output = [f for f in output if f != "_score"]
 
         # Prepare expressions common to all tables
         filter_cond = None
@@ -339,7 +341,7 @@ class InfinityConnection(DocStoreConnection):
                     matchExpr.extra_options[k] = str(v)
                 logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
             elif isinstance(matchExpr, MatchDenseExpr):
-                if filter_fulltext and filter_cond and "filter" not in matchExpr.extra_options:
+                if filter_fulltext and "filter" not in matchExpr.extra_options:
                     matchExpr.extra_options.update({"filter": filter_fulltext})
                 for k, v in matchExpr.extra_options.items():
                     if not isinstance(v, str):
@@ -370,7 +372,7 @@ class InfinityConnection(DocStoreConnection):
             except Exception:
                 continue
             table_list.append(table_name)
-            builder = table_instance.output(selectFields)
+            builder = table_instance.output(output)
             if len(matchExprs) > 0:
                 for matchExpr in matchExprs:
                     if isinstance(matchExpr, MatchTextExpr):
@@ -379,7 +381,7 @@ class InfinityConnection(DocStoreConnection):
                             fields,
                             matchExpr.matching_text,
                             matchExpr.topn,
-                            matchExpr.extra_options,
+                            matchExpr.extra_options.copy(),
                         )
                     elif isinstance(matchExpr, MatchDenseExpr):
                         builder = builder.match_dense(
@@ -388,7 +390,7 @@ class InfinityConnection(DocStoreConnection):
                             matchExpr.embedding_data_type,
                             matchExpr.distance_type,
                             matchExpr.topn,
-                            matchExpr.extra_options,
+                            matchExpr.extra_options.copy(),
                         )
                     elif isinstance(matchExpr, FusionExpr):
                         builder = builder.fusion(
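Aside (illustration only): match_text and match_dense now receive extra_options.copy() because the same MatchExpr objects are reused for every table in the loop; if the SDK normalizes the dict in place, the copy keeps one iteration's changes from leaking into the next. A toy sketch with a hypothetical mutating callee standing in for the SDK:

def match_text_stub(options):
    options["filter"] = "kb_id = 'x'"  # simulated SDK-side mutation

extra_options = {"topn": "6"}
for _ in range(3):  # one call per table
    match_text_stub(extra_options.copy())
print(extra_options)  # {'topn': '6'} -- unchanged across iterations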
@@ -400,18 +402,17 @@ class InfinityConnection(DocStoreConnection):
             if orderBy.fields:
                 builder.sort(order_by_expr_list)
             builder.offset(offset).limit(limit)
-            kb_res, extra_result = builder.option({"total_hits_count": True}).to_pl()
+            kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
             if extra_result:
                 total_hits_count += int(extra_result["total_hits_count"])
             logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
             df_list.append(kb_res)
         self.connPool.release_conn(inf_conn)
-        res = concat_dataframes(df_list, selectFields)
+        res = concat_dataframes(df_list, output)
         if matchExprs:
-            res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
-            if score_column and score_column != "SCORE":
-                res = res.rename({score_column: "_score"})
-            res = res.limit(limit)
+            res['Sum'] = res[score_column] + res[PAGERANK_FLD]
+            res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
+            res = res.head(limit)
         logger.debug(f"INFINITY search final result: {str(res)}")
         return res, total_hits_count
 
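Reviewer note (not part of the patch): pandas has no counterpart to the polars sort-by-expression res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), ...), so the patch materializes the combined key as a temporary Sum column, sorts on it, and drops it. A minimal sketch with invented scores; the value of PAGERANK_FLD is an assumption here:

import pandas as pd

PAGERANK_FLD = "pagerank_fea"  # assumed value of the settings constant
res = pd.DataFrame({
    "id": ["c1", "c2", "c3"],
    "SCORE": [0.20, 0.90, 0.50],
    PAGERANK_FLD: [0.40, 0.00, 0.05],
})
res['Sum'] = res["SCORE"] + res[PAGERANK_FLD]
res = res.sort_values(by='Sum', ascending=False).reset_index(drop=True).drop(columns=['Sum'])
print(res["id"].tolist())  # ['c2', 'c1', 'c3'] -- ranked by score + pagerank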
@@ -433,12 +434,12 @@ class InfinityConnection(DocStoreConnection):
                 logger.warning(
                     f"Table not found: {table_name}, this knowledge base isn't created in Infinity. Maybe it is created in other document engine.")
                 continue
-            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_pl()
+            kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
             logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
             df_list.append(kb_res)
         self.connPool.release_conn(inf_conn)
         res = concat_dataframes(df_list, ["id"])
-        res_fields = self.getFields(res, res.columns)
+        res_fields = self.getFields(res, res.columns.tolist())
         return res_fields.get(chunkId, None)
 
     def insert(
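Aside (illustration only): the added .tolist() matters because pandas DataFrame.columns is an Index, while the rewritten getFields() copies and appends to fields as a plain Python list:

import pandas as pd

df = pd.DataFrame(columns=["id", "content_with_weight"])
print(type(df.columns).__name__)  # Index
print(df.columns.tolist())        # ['id', 'content_with_weight']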
@@ -572,60 +573,54 @@ class InfinityConnection(DocStoreConnection):
     Helper functions for search result
     """
 
-    def getTotal(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> int:
+    def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
         if isinstance(res, tuple):
             return res[1]
         return len(res)
 
-    def getChunkIds(self, res: tuple[pl.DataFrame, int] | pl.DataFrame) -> list[str]:
+    def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
         if isinstance(res, tuple):
             res = res[0]
         return list(res["id"])
 
-    def getFields(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fields: list[str]) -> list[str, dict]:
+    def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
         if isinstance(res, tuple):
             res = res[0]
-        res_fields = {}
         if not fields:
             return {}
-        num_rows = len(res)
-        column_id = res["id"]
-        for i in range(num_rows):
-            id = column_id[i]
-            m = {"id": id}
-            for fieldnm in fields:
-                if fieldnm not in res:
-                    m[fieldnm] = None
-                    continue
-                v = res[fieldnm][i]
-                if isinstance(v, Series):
-                    v = list(v)
-                elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
-                    assert isinstance(v, str)
-                    v = [kwd for kwd in v.split("###") if kwd]
-                elif fieldnm == "position_int":
-                    assert isinstance(v, str)
+        fieldsAll = fields.copy()
+        fieldsAll.append('id')
+        column_map = {col.lower(): col for col in res.columns}
+        matched_columns = {column_map[col.lower()]:col for col in set(fieldsAll) if col.lower() in column_map}
+        none_columns = [col for col in set(fieldsAll) if col.lower() not in column_map]
+
+        res2 = res[matched_columns.keys()]
+        res2 = res2.rename(columns=matched_columns)
+        res2.drop_duplicates(subset=['id'], inplace=True)
+
+        for column in res2.columns:
+            k = column.lower()
+            if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]:
+                res2[column] = res2[column].apply(lambda v:[kwd for kwd in v.split("###") if kwd])
+            elif k == "position_int":
+                def to_position_int(v):
                     if v:
                         arr = [int(hex_val, 16) for hex_val in v.split('_')]
                         v = [arr[i:i + 5] for i in range(0, len(arr), 5)]
                     else:
                         v = []
-                elif fieldnm in ["page_num_int", "top_int"]:
-                    assert isinstance(v, str)
-                    if v:
-                        v = [int(hex_val, 16) for hex_val in v.split('_')]
-                    else:
-                        v = []
-                else:
-                    if not isinstance(v, str):
-                        v = str(v)
-                    # if fieldnm.endswith("_tks"):
-                    #     v = rmSpace(v)
-                m[fieldnm] = v
-            res_fields[id] = m
-        return res_fields
+                    return v
+                res2[column] = res2[column].apply(to_position_int)
+            elif k in ["page_num_int", "top_int"]:
+                res2[column] = res2[column].apply(lambda v:[int(hex_val, 16) for hex_val in v.split('_')] if v else [])
+            else:
+                pass
+        for column in none_columns:
+            res2[column] = None
+
+        return res2.set_index("id").to_dict(orient="index")
 
-    def getHighlight(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, keywords: list[str], fieldnm: str):
+    def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
         if isinstance(res, tuple):
             res = res[0]
         ans = {}
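Reviewer note (not part of the patch): end-to-end shape of the rewritten getFields on toy data: list-valued fields are split on "###", position_int is decoded from underscore-separated hex into 5-element groups, and the frame is finally pivoted into a dict keyed by chunk id. Sample values are invented.

import pandas as pd

res = pd.DataFrame({
    "id": ["c1", "c2"],
    "important_kwd": ["alpha###beta", ""],
    "position_int": ["1_2_3_4_5", ""],
})
res["important_kwd"] = res["important_kwd"].apply(lambda v: [kwd for kwd in v.split("###") if kwd])

def to_position_int(v):
    if v:
        arr = [int(hex_val, 16) for hex_val in v.split('_')]
        return [arr[i:i + 5] for i in range(0, len(arr), 5)]
    return []

res["position_int"] = res["position_int"].apply(to_position_int)
print(res.set_index("id").to_dict(orient="index"))
# {'c1': {'important_kwd': ['alpha', 'beta'], 'position_int': [[1, 2, 3, 4, 5]]},
#  'c2': {'important_kwd': [], 'position_int': []}}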
@@ -655,7 +650,7 @@ class InfinityConnection(DocStoreConnection):
             ans[id] = "...".join(txts)
         return ans
 
-    def getAggregation(self, res: tuple[pl.DataFrame, int] | pl.DataFrame, fieldnm: str):
+    def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
        """
        TODO: Infinity doesn't provide aggregation
        """