|  |  | @@ -13,7 +13,8 @@ from rag import settings | 
		
	
		
			
			|  |  |  | from rag.utils import singleton | 
		
	
		
			
			|  |  |  | from api.utils.file_utils import get_project_base_directory | 
		
	
		
			
			|  |  |  | import polars as pl | 
		
	
		
			
			|  |  |  | from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr | 
		
	
		
			
			|  |  |  | from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ | 
		
	
		
			
			|  |  |  | FusionExpr | 
		
	
		
			
			|  |  |  | from rag.nlp import is_english, rag_tokenizer | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
	
		
			
			|  |  | @@ -26,7 +27,8 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | self.es = Elasticsearch( | 
		
	
		
			
			|  |  |  | settings.ES["hosts"].split(","), | 
		
	
		
			
			|  |  |  | basic_auth=(settings.ES["username"], settings.ES["password"]) if "username" in settings.ES and "password" in settings.ES else None, | 
		
	
		
			
			|  |  |  | basic_auth=(settings.ES["username"], settings.ES[ | 
		
	
		
			
			|  |  |  | "password"]) if "username" in settings.ES and "password" in settings.ES else None, | 
		
	
		
			
			|  |  |  | verify_certs=False, | 
		
	
		
			
			|  |  |  | timeout=600 | 
		
	
		
			
			|  |  |  | ) | 
		
	
	
		
			
			|  |  | @@ -57,6 +59,7 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | Database operations | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def dbType(self) -> str: | 
		
	
		
			
			|  |  |  | return "elasticsearch" | 
		
	
		
			
			|  |  |  | 
 | 
		
	
	
		
			
			|  |  | @@ -66,6 +69,7 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | Table operations | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): | 
		
	
		
			
			|  |  |  | if self.indexExist(indexName, knowledgebaseId): | 
		
	
		
			
			|  |  |  | return True | 
		
	
	
		
			
			|  |  | @@ -97,7 +101,10 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | CRUD operations | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame: | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr], | 
		
	
		
			
			|  |  |  | orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str], | 
		
	
		
			
			|  |  |  | knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame: | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html | 
		
	
		
			
			|  |  |  | """ | 
		
	
	
		
			
			|  |  | @@ -109,8 +116,10 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | bqry = None | 
		
	
		
			
			|  |  |  | vector_similarity_weight = 0.5 | 
		
	
		
			
			|  |  |  | for m in matchExprs: | 
		
	
		
			
			|  |  |  | if isinstance(m, FusionExpr) and m.method=="weighted_sum" and "weights" in m.fusion_params: | 
		
	
		
			
			|  |  |  | assert len(matchExprs)==3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], MatchDenseExpr) and isinstance(matchExprs[2], FusionExpr) | 
		
	
		
			
			|  |  |  | if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: | 
		
	
		
			
			|  |  |  | assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1], | 
		
	
		
			
			|  |  |  | MatchDenseExpr) and isinstance( | 
		
	
		
			
			|  |  |  | matchExprs[2], FusionExpr) | 
		
	
		
			
			|  |  |  | weights = m.fusion_params["weights"] | 
		
	
		
			
			|  |  |  | vector_similarity_weight = float(weights.split(",")[1]) | 
		
	
		
			
			|  |  |  | for m in matchExprs: | 
		
	
	
		
			
			|  |  | @@ -119,36 +128,41 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | if "minimum_should_match" in m.extra_options: | 
		
	
		
			
			|  |  |  | minimum_should_match = str(int(m.extra_options["minimum_should_match"] * 100)) + "%" | 
		
	
		
			
			|  |  |  | bqry = Q("bool", | 
		
	
		
			
			|  |  |  | must=Q("query_string", fields=m.fields, | 
		
	
		
			
			|  |  |  | must=Q("query_string", fields=m.fields, | 
		
	
		
			
			|  |  |  | type="best_fields", query=m.matching_text, | 
		
	
		
			
			|  |  |  | minimum_should_match = minimum_should_match, | 
		
	
		
			
			|  |  |  | minimum_should_match=minimum_should_match, | 
		
	
		
			
			|  |  |  | boost=1), | 
		
	
		
			
			|  |  |  | boost = 1.0 - vector_similarity_weight, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | if condition: | 
		
	
		
			
			|  |  |  | for k, v in condition.items(): | 
		
	
		
			
			|  |  |  | if not isinstance(k, str) or not v: | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | if isinstance(v, list): | 
		
	
		
			
			|  |  |  | bqry.filter.append(Q("terms", **{k: v})) | 
		
	
		
			
			|  |  |  | elif isinstance(v, str) or isinstance(v, int): | 
		
	
		
			
			|  |  |  | bqry.filter.append(Q("term", **{k: v})) | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") | 
		
	
		
			
			|  |  |  | boost=1.0 - vector_similarity_weight, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | elif isinstance(m, MatchDenseExpr): | 
		
	
		
			
			|  |  |  | assert(bqry is not None) | 
		
	
		
			
			|  |  |  | assert (bqry is not None) | 
		
	
		
			
			|  |  |  | similarity = 0.0 | 
		
	
		
			
			|  |  |  | if "similarity" in m.extra_options: | 
		
	
		
			
			|  |  |  | similarity = m.extra_options["similarity"] | 
		
	
		
			
			|  |  |  | s = s.knn(m.vector_column_name, | 
		
	
		
			
			|  |  |  | m.topn, | 
		
	
		
			
			|  |  |  | m.topn * 2, | 
		
	
		
			
			|  |  |  | query_vector = list(m.embedding_data), | 
		
	
		
			
			|  |  |  | filter = bqry.to_dict(), | 
		
	
		
			
			|  |  |  | similarity = similarity, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | if matchExprs: | 
		
	
		
			
			|  |  |  | s.query = bqry | 
		
	
		
			
			|  |  |  | m.topn, | 
		
	
		
			
			|  |  |  | m.topn * 2, | 
		
	
		
			
			|  |  |  | query_vector=list(m.embedding_data), | 
		
	
		
			
			|  |  |  | filter=bqry.to_dict(), | 
		
	
		
			
			|  |  |  | similarity=similarity, | 
		
	
		
			
			|  |  |  | ) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if condition: | 
		
	
		
			
			|  |  |  | if not bqry: | 
		
	
		
			
			|  |  |  | bqry = Q("bool", must=[]) | 
		
	
		
			
			|  |  |  | for k, v in condition.items(): | 
		
	
		
			
			|  |  |  | if not isinstance(k, str) or not v: | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | if isinstance(v, list): | 
		
	
		
			
			|  |  |  | bqry.filter.append(Q("terms", **{k: v})) | 
		
	
		
			
			|  |  |  | elif isinstance(v, str) or isinstance(v, int): | 
		
	
		
			
			|  |  |  | bqry.filter.append(Q("term", **{k: v})) | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | raise Exception( | 
		
	
		
			
			|  |  |  | f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if bqry: | 
		
	
		
			
			|  |  |  | s = s.query(bqry) | 
		
	
		
			
			|  |  |  | for field in highlightFields: | 
		
	
		
			
			|  |  |  | s = s.highlight(field) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
	
		
			
			|  |  | @@ -157,12 +171,13 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | for field, order in orderBy.fields: | 
		
	
		
			
			|  |  |  | order = "asc" if order == 0 else "desc" | 
		
	
		
			
			|  |  |  | orders.append({field: {"order": order, "unmapped_type": "float", | 
		
	
		
			
			|  |  |  | "mode": "avg", "numeric_type": "double"}}) | 
		
	
		
			
			|  |  |  | "mode": "avg", "numeric_type": "double"}}) | 
		
	
		
			
			|  |  |  | s = s.sort(*orders) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | if limit > 0: | 
		
	
		
			
			|  |  |  | s = s[offset:limit] | 
		
	
		
			
			|  |  |  | q = s.to_dict() | 
		
	
		
			
			|  |  |  | print(json.dumps(q), flush=True) | 
		
	
		
			
			|  |  |  | # logger.info("ESConnection.search [Q]: " + json.dumps(q)) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | for i in range(3): | 
		
	
	
		
			
			|  |  | @@ -189,7 +204,7 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | for i in range(3): | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | res = self.es.get(index=(indexName), | 
		
	
		
			
			|  |  |  | id=chunkId, source=True,) | 
		
	
		
			
			|  |  |  | id=chunkId, source=True, ) | 
		
	
		
			
			|  |  |  | if str(res.get("timed_out", "")).lower() == "true": | 
		
	
		
			
			|  |  |  | raise Exception("Es Timeout.") | 
		
	
		
			
			|  |  |  | if not res.get("found"): | 
		
	
	
		
			
			|  |  | @@ -222,7 +237,7 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | for _ in range(100): | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | r = self.es.bulk(index=(indexName), operations=operations, | 
		
	
		
			
			|  |  |  | refresh=False, timeout="600s") | 
		
	
		
			
			|  |  |  | refresh=False, timeout="600s") | 
		
	
		
			
			|  |  |  | if re.search(r"False", str(r["errors"]), re.IGNORECASE): | 
		
	
		
			
			|  |  |  | return res | 
		
	
		
			
			|  |  |  | 
 | 
		
	
	
		
			
			|  |  | @@ -249,7 +264,8 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | self.es.update(index=indexName, id=chunkId, doc=doc) | 
		
	
		
			
			|  |  |  | return True | 
		
	
		
			
			|  |  |  | except Exception as e: | 
		
	
		
			
			|  |  |  | logger.exception(f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})") | 
		
	
		
			
			|  |  |  | logger.exception( | 
		
	
		
			
			|  |  |  | f"ES failed to update(index={indexName}, id={id}, doc={json.dumps(condition, ensure_ascii=False)})") | 
		
	
		
			
			|  |  |  | if str(e).find("Timeout") > 0: | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | else: | 
		
	
	
		
			
			|  |  | @@ -263,7 +279,8 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | elif isinstance(v, str) or isinstance(v, int): | 
		
	
		
			
			|  |  |  | bqry.filter.append(Q("term", **{k: v})) | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | raise Exception(f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") | 
		
	
		
			
			|  |  |  | raise Exception( | 
		
	
		
			
			|  |  |  | f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.") | 
		
	
		
			
			|  |  |  | scripts = [] | 
		
	
		
			
			|  |  |  | for k, v in newValue.items(): | 
		
	
		
			
			|  |  |  | if not isinstance(k, str) or not v: | 
		
	
	
		
			
			|  |  | @@ -273,7 +290,8 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | elif isinstance(v, int): | 
		
	
		
			
			|  |  |  | scripts.append(f"ctx._source.{k} = {v}") | 
		
	
		
			
			|  |  |  | else: | 
		
	
		
			
			|  |  |  | raise Exception(f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") | 
		
	
		
			
			|  |  |  | raise Exception( | 
		
	
		
			
			|  |  |  | f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.") | 
		
	
		
			
			|  |  |  | ubq = UpdateByQuery( | 
		
	
		
			
			|  |  |  | index=indexName).using( | 
		
	
		
			
			|  |  |  | self.es).query(bqry) | 
		
	
	
		
			
			|  |  | @@ -313,7 +331,7 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | res = self.es.delete_by_query( | 
		
	
		
			
			|  |  |  | index=indexName, | 
		
	
		
			
			|  |  |  | body = Search().query(qry).to_dict(), | 
		
	
		
			
			|  |  |  | body=Search().query(qry).to_dict(), | 
		
	
		
			
			|  |  |  | refresh=True) | 
		
	
		
			
			|  |  |  | return res["deleted"] | 
		
	
		
			
			|  |  |  | except Exception as e: | 
		
	
	
		
			
			|  |  | @@ -325,10 +343,10 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | return 0 | 
		
	
		
			
			|  |  |  | return 0 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | Helper functions for search result | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def getTotal(self, res): | 
		
	
		
			
			|  |  |  | if isinstance(res["hits"]["total"], type({})): | 
		
	
		
			
			|  |  |  | return res["hits"]["total"]["value"] | 
		
	
	
		
			
			|  |  | @@ -376,12 +394,13 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | txt = d["_source"][fieldnm] | 
		
	
		
			
			|  |  |  | txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE|re.MULTILINE) | 
		
	
		
			
			|  |  |  | txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) | 
		
	
		
			
			|  |  |  | txts = [] | 
		
	
		
			
			|  |  |  | for t in re.split(r"[.?!;\n]", txt): | 
		
	
		
			
			|  |  |  | for w in keywords: | 
		
	
		
			
			|  |  |  | t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])"%re.escape(w), r"\1<em>\2</em>\3", t, flags=re.IGNORECASE|re.MULTILINE) | 
		
	
		
			
			|  |  |  | if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE|re.MULTILINE): | 
		
	
		
			
			|  |  |  | t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t, | 
		
	
		
			
			|  |  |  | flags=re.IGNORECASE | re.MULTILINE) | 
		
	
		
			
			|  |  |  | if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE): | 
		
	
		
			
			|  |  |  | continue | 
		
	
		
			
			|  |  |  | txts.append(t) | 
		
	
		
			
			|  |  |  | ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]]) | 
		
	
	
		
			
			|  |  | @@ -395,10 +414,10 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | bkts = res["aggregations"][agg_field]["buckets"] | 
		
	
		
			
			|  |  |  | return [(b["key"], b["doc_count"]) for b in bkts] | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | SQL | 
		
	
		
			
			|  |  |  | """ | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | def sql(self, sql: str, fetch_size: int, format: str): | 
		
	
		
			
			|  |  |  | logger.info(f"ESConnection.sql get sql: {sql}") | 
		
	
		
			
			|  |  |  | sql = re.sub(r"[ `]+", " ", sql) | 
		
	
	
		
			
			|  |  | @@ -413,7 +432,7 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | r.group(1), | 
		
	
		
			
			|  |  |  | r.group(2), | 
		
	
		
			
			|  |  |  | r.group(3)), | 
		
	
		
			
			|  |  |  | match)) | 
		
	
		
			
			|  |  |  | match)) | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | for p, r in replaces: | 
		
	
		
			
			|  |  |  | sql = sql.replace(p, r, 1) | 
		
	
	
		
			
			|  |  | @@ -421,7 +440,8 @@ class ESConnection(DocStoreConnection): | 
		
	
		
			
			|  |  |  | 
 | 
		
	
		
			
			|  |  |  | for i in range(3): | 
		
	
		
			
			|  |  |  | try: | 
		
	
		
			
			|  |  |  | res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, request_timeout="2s") | 
		
	
		
			
			|  |  |  | res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format, | 
		
	
		
			
			|  |  |  | request_timeout="2s") | 
		
	
		
			
			|  |  |  | return res | 
		
	
		
			
			|  |  |  | except ConnectionTimeout: | 
		
	
		
			
			|  |  |  | logger.exception("ESConnection.sql timeout [Q]: " + sql) |