|
|
|
@@ -44,8 +44,23 @@ from rag.utils.doc_store_conn import ( |
|
|
|
logger = logging.getLogger('ragflow.infinity_conn') |
|
|
|
|
|
|
|
|
|
|
|
def equivalent_condition_to_str(condition: dict) -> str | None: |
|
|
|
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None: |
|
|
|
assert "_id" not in condition |
|
|
|
clmns = {} |
|
|
|
if table_instance: |
|
|
|
for n, ty, de, _ in table_instance.show_columns().rows(): |
|
|
|
clmns[n] = (ty, de) |
|
|
|
|
|
|
|
def exists(cln): |
|
|
|
nonlocal clmns |
|
|
|
assert cln in clmns, f"'{cln}' should be in '{clmns}'." |
|
|
|
ty, de = clmns[cln] |
|
|
|
if ty.lower().find("cha"): |
|
|
|
if not de: |
|
|
|
de = "" |
|
|
|
return f" {cln}!='{de}' " |
|
|
|
return f"{cln}!={de}" |
|
|
|
|
|
|
|
cond = list() |
|
|
|
for k, v in condition.items(): |
|
|
|
if not isinstance(k, str) or k in ["kb_id"] or not v: |
|
|
|
@@ -61,8 +76,15 @@ def equivalent_condition_to_str(condition: dict) -> str | None: |
|
|
|
strInCond = ", ".join(inCond) |
|
|
|
strInCond = f"{k} IN ({strInCond})" |
|
|
|
cond.append(strInCond) |
|
|
|
elif k == "must_not": |
|
|
|
if isinstance(v, dict): |
|
|
|
for kk, vv in v.items(): |
|
|
|
if kk == "exists": |
|
|
|
cond.append("NOT (%s)" % exists(vv)) |
|
|
|
elif isinstance(v, str): |
|
|
|
cond.append(f"{k}='{v}'") |
|
|
|
elif k == "exists": |
|
|
|
cond.append(exists(v)) |
|
|
|
else: |
|
|
|
cond.append(f"{k}={str(v)}") |
|
|
|
return " AND ".join(cond) if cond else "1=1" |
|
|
|
@@ -294,7 +316,11 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
filter_cond = None |
|
|
|
filter_fulltext = "" |
|
|
|
if condition: |
|
|
|
filter_cond = equivalent_condition_to_str(condition) |
|
|
|
for indexName in indexNames: |
|
|
|
table_name = f"{indexName}_{knowledgebaseIds[0]}" |
|
|
|
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name)) |
|
|
|
break |
|
|
|
|
|
|
|
for matchExpr in matchExprs: |
|
|
|
if isinstance(matchExpr, MatchTextExpr): |
|
|
|
if filter_cond and "filter" not in matchExpr.extra_options: |
|
|
|
@@ -434,12 +460,21 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
self.createIdx(indexName, knowledgebaseId, vector_size) |
|
|
|
table_instance = db_instance.get_table(table_name) |
|
|
|
|
|
|
|
# embedding fields can't have a default value.... |
|
|
|
embedding_clmns = [] |
|
|
|
clmns = table_instance.show_columns().rows() |
|
|
|
for n, ty, _, _ in clmns: |
|
|
|
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty) |
|
|
|
if not r: |
|
|
|
continue |
|
|
|
embedding_clmns.append((n, int(r.group(1)))) |
|
|
|
|
|
|
|
docs = copy.deepcopy(documents) |
|
|
|
for d in docs: |
|
|
|
assert "_id" not in d |
|
|
|
assert "id" in d |
|
|
|
for k, v in d.items(): |
|
|
|
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]: |
|
|
|
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: |
|
|
|
assert isinstance(v, list) |
|
|
|
d[k] = "###".join(v) |
|
|
|
elif re.search(r"_feas$", k): |
|
|
|
@@ -454,6 +489,11 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
elif k in ["page_num_int", "top_int"]: |
|
|
|
assert isinstance(v, list) |
|
|
|
d[k] = "_".join(f"{num:08x}" for num in v) |
|
|
|
|
|
|
|
for n, vs in embedding_clmns: |
|
|
|
if n in d: |
|
|
|
continue |
|
|
|
d[n] = [0] * vs |
|
|
|
ids = ["'{}'".format(d["id"]) for d in docs] |
|
|
|
str_ids = ", ".join(ids) |
|
|
|
str_filter = f"id IN ({str_ids})" |
|
|
|
@@ -475,11 +515,11 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
db_instance = inf_conn.get_database(self.dbName) |
|
|
|
table_name = f"{indexName}_{knowledgebaseId}" |
|
|
|
table_instance = db_instance.get_table(table_name) |
|
|
|
if "exist" in condition: |
|
|
|
del condition["exist"] |
|
|
|
filter = equivalent_condition_to_str(condition) |
|
|
|
#if "exists" in condition: |
|
|
|
# del condition["exists"] |
|
|
|
filter = equivalent_condition_to_str(condition, table_instance) |
|
|
|
for k, v in list(newValue.items()): |
|
|
|
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]: |
|
|
|
if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: |
|
|
|
assert isinstance(v, list) |
|
|
|
newValue[k] = "###".join(v) |
|
|
|
elif re.search(r"_feas$", k): |
|
|
|
@@ -496,9 +536,11 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
elif k in ["page_num_int", "top_int"]: |
|
|
|
assert isinstance(v, list) |
|
|
|
newValue[k] = "_".join(f"{num:08x}" for num in v) |
|
|
|
elif k == "remove" and v in [PAGERANK_FLD]: |
|
|
|
elif k == "remove": |
|
|
|
del newValue[k] |
|
|
|
newValue[v] = 0 |
|
|
|
if v in [PAGERANK_FLD]: |
|
|
|
newValue[v] = 0 |
|
|
|
|
|
|
|
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.") |
|
|
|
table_instance.update(filter, newValue) |
|
|
|
self.connPool.release_conn(inf_conn) |
|
|
|
@@ -508,14 +550,14 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
inf_conn = self.connPool.get_conn() |
|
|
|
db_instance = inf_conn.get_database(self.dbName) |
|
|
|
table_name = f"{indexName}_{knowledgebaseId}" |
|
|
|
filter = equivalent_condition_to_str(condition) |
|
|
|
try: |
|
|
|
table_instance = db_instance.get_table(table_name) |
|
|
|
except Exception: |
|
|
|
logger.warning( |
|
|
|
f"Skipped deleting `{filter}` from table {table_name} since the table doesn't exist." |
|
|
|
f"Skipped deleting from table {table_name} since the table doesn't exist." |
|
|
|
) |
|
|
|
return 0 |
|
|
|
filter = equivalent_condition_to_str(condition, table_instance) |
|
|
|
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.") |
|
|
|
res = table_instance.delete(filter) |
|
|
|
self.connPool.release_conn(inf_conn) |
|
|
|
@@ -553,7 +595,7 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
v = res[fieldnm][i] |
|
|
|
if isinstance(v, Series): |
|
|
|
v = list(v) |
|
|
|
elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd"]: |
|
|
|
elif fieldnm in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: |
|
|
|
assert isinstance(v, str) |
|
|
|
v = [kwd for kwd in v.split("###") if kwd] |
|
|
|
elif fieldnm == "position_int": |
|
|
|
@@ -584,6 +626,8 @@ class InfinityConnection(DocStoreConnection): |
|
|
|
ans = {} |
|
|
|
num_rows = len(res) |
|
|
|
column_id = res["id"] |
|
|
|
if fieldnm not in res: |
|
|
|
return {} |
|
|
|
for i in range(num_rows): |
|
|
|
id = column_id[i] |
|
|
|
txt = res[fieldnm][i] |