@@ -1,13 +1,10 @@
+import copy
 import json
 import logging
-from collections.abc import Iterable
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
 from pydantic import BaseModel, model_validator
-from tenacity import retry, stop_after_attempt, wait_fixed
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -23,11 +20,15 @@ logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)
 
+ROUTING_FIELD = "routing_field"
+UGC_INDEX_PREFIX = "ugc_index"
+
 
 class LindormVectorStoreConfig(BaseModel):
     hosts: str
     username: Optional[str] = None
     password: Optional[str] = None
+    using_ugc: Optional[bool] = False
 
     @model_validator(mode="before")
     @classmethod
@@ -41,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
         return values
 
     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts,
-        }
+        params = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
@@ -51,9 +50,21 @@ class LindormVectorStoreConfig(BaseModel):
 
 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
-        super().__init__(collection_name.lower())
+        self._routing = None
+        self._routing_field = None
+        if config.using_ugc:
+            routing_value: str = kwargs.get("routing_value")
+            if routing_value is None:
+                raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
+            self._routing = routing_value.lower()
+            self._routing_field = ROUTING_FIELD
+            ugc_index_name = collection_name
+            super().__init__(ugc_index_name.lower())
+        else:
+            super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
+        self._using_ugc = config.using_ugc
+        self.kwargs = kwargs
 
     def get_type(self) -> str:
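
Note: with using_ugc=True the constructor now requires a routing_value kwarg, lowercased and used as the OpenSearch routing key for every subsequent request. A minimal usage sketch (endpoint, credentials, and names below are illustrative placeholders, not values from this change):

    config = LindormVectorStoreConfig(
        hosts="http://lindorm-host:30070",  # placeholder endpoint
        username="user",                    # placeholder credentials
        password="pass",
        using_ugc=True,
    )
    # One shared physical index, one routing value per logical collection;
    # omitting routing_value raises ValueError in __init__.
    store = LindormVectorStore("ugc_index_1024_hnsw_l2", config, routing_value="vector_index_abc123_node")
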
@@ -66,89 +77,37 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def __filter_existed_ids(
-        self,
-        texts: list[str],
-        metadatas: list[dict],
-        ids: list[str],
-        bulk_size: int = 1024,
-    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch {batch_ids}")
-                return set()
-
-        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
-        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
-            try:
-                existing_docs = self._client.mget(
-                    body={
-                        "docs": [
-                            {"_index": self._collection_name, "_id": id, "routing": routing}
-                            for id, routing in zip(batch_ids, route_ids)
-                        ]
-                    },
-                    _source=False,
-                )
-                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
-            except Exception as e:
-                logger.exception(f"Error fetching batch ids: {batch_ids}")
-                return set()
-
-        if ids is None:
-            return texts, metadatas, ids
-
-        if len(texts) != len(ids):
-            raise RuntimeError(f"texts {len(texts)} != {ids}")
-
-        filtered_texts = []
-        filtered_metadatas = []
-        filtered_ids = []
-
-        def batch(iterable, n):
-            length = len(iterable)
-            for idx in range(0, length, n):
-                yield iterable[idx : min(idx + n, length)]
-
-        for ids_batch, texts_batch, metadatas_batch in zip(
-            batch(ids, bulk_size),
-            batch(texts, bulk_size),
-            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
-        ):
-            existing_ids_set = __fetch_existing_ids(ids_batch)
-            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
-                if doc_id not in existing_ids_set:
-                    filtered_texts.append(text)
-                    filtered_ids.append(doc_id)
-                    if metadatas is not None:
-                        filtered_metadatas.append(metadata)
-
-        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
-
     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         actions = []
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
-            action = {
-                "_op_type": "index",
-                "_index": self._collection_name.lower(),
-                "_id": uuids[i],
-                "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
-                },
-            }
+            action_header = {
+                "index": {
+                    "_index": self.collection_name.lower(),
+                    "_id": uuids[i],
+                }
+            }
+            action_values = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                Field.METADATA_KEY.value: documents[i].metadata,
+            }
-            actions.append(action)
-        bulk(self._client, actions)
-        self.refresh()
+            if self._using_ugc:
+                action_header["index"]["routing"] = self._routing
+                action_values[self._routing_field] = self._routing
+            actions.append(action_header)
+            actions.append(action_values)
+        response = self._client.bulk(actions)
+        if response["errors"]:
+            for item in response["items"]:
+                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+        else:
+            self.refresh()
 
     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+        if self._using_ugc:
+            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
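
Note: add_texts switches from opensearchpy.helpers.bulk to the raw client.bulk API, whose body alternates an action-header entry with a document-source entry. A sketch of the payload for two documents, assuming Field.CONTENT_KEY, Field.VECTOR, and Field.METADATA_KEY resolve to "page_content", "vector", and "metadata" (ids and values illustrative):

    actions = [
        {"index": {"_index": "ugc_index_1024_hnsw_l2", "_id": "uuid-1", "routing": "vector_index_abc123_node"}},
        {"page_content": "first chunk", "vector": [0.1, 0.2], "metadata": {"doc_id": "uuid-1"}, "routing_field": "vector_index_abc123_node"},
        {"index": {"_index": "ugc_index_1024_hnsw_l2", "_id": "uuid-2", "routing": "vector_index_abc123_node"}},
        {"page_content": "second chunk", "vector": [0.3, 0.4], "metadata": {"doc_id": "uuid-2"}, "routing_field": "vector_index_abc123_node"},
    ]
    response = client.bulk(actions)  # response["errors"] is True if any item failed
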
@@ -156,50 +115,62 @@ class LindormVectorStore(BaseVector):
         return None
 
     def delete_by_metadata_field(self, key: str, value: str):
-        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
-        results = self._client.search(index=self._collection_name, body=query_str)
-        ids = [hit["_id"] for hit in results["hits"]["hits"]]
+        ids = self.get_ids_by_metadata_field(key, value)
         if ids:
             self.delete_by_ids(ids)
 
     def delete_by_ids(self, ids: list[str]) -> None:
+        params = {}
+        if self._using_ugc:
+            params["routing"] = self._routing
         for id in ids:
-            if self._client.exists(index=self._collection_name, id=id):
-                self._client.delete(index=self._collection_name, id=id)
+            if self._client.exists(index=self._collection_name, id=id, params=params):
+                self._client.delete(index=self._collection_name, id=id, params=params)
                 self.refresh()
             else:
                 logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
 
     def delete(self) -> None:
         try:
-            if self._client.indices.exists(index=self._collection_name):
-                self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
-                logger.info("Delete index success")
-            else:
-                logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
+            if self._using_ugc:
+                routing_filter_query = {
+                    "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
+                }
+                self._client.delete_by_query(self._collection_name, body=routing_filter_query)
+                self.refresh()
+            else:
+                if self._client.indices.exists(index=self._collection_name):
+                    self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
+                    logger.info("Delete index success")
+                else:
+                    logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
         except Exception as e:
             logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
             raise e
 
     def text_exists(self, id: str) -> bool:
         try:
-            self._client.get(index=self._collection_name, id=id)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            self._client.get(index=self._collection_name, id=id, params=params)
             return True
         except:
             return False
 
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         # Make sure query_vector is a list
         if not isinstance(query_vector, list):
             raise ValueError("query_vector should be a list of floats")
 
         # Check whether query_vector is a floating-point number list
         if not all(isinstance(x, float) for x in query_vector):
             raise ValueError("All elements in query_vector should be floats")
 
         top_k = kwargs.get("top_k", 10)
         query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
-            response = self._client.search(index=self._collection_name, body=query)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            response = self._client.search(index=self._collection_name, body=query, params=params)
         except Exception as e:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
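
Note: every per-document operation (exists, delete, get, search) now threads the same routing through params, since a UGC document lives only on the shard its routing value maps to. The call shape, as a sketch (names illustrative):

    params = {"routing": "vector_index_abc123_node"} if using_ugc else {}
    client.get(index=index_name, id=doc_id, params=params)
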
@@ -232,7 +203,7 @@ class LindormVectorStore(BaseVector):
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
         filters = kwargs.get("filter")
-        routing = kwargs.get("routing")
+        routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,
             k=top_k,
@@ -243,6 +214,7 @@ class LindormVectorStore(BaseVector):
             minimum_should_match=minimum_should_match,
             filters=filters,
             routing=routing,
+            routing_field=self._routing_field,
         )
         response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
@@ -265,17 +237,18 @@ class LindormVectorStore(BaseVector):
                 logger.info(f"Collection {self._collection_name} already exists.")
                 return
             if self._client.indices.exists(index=self._collection_name):
-                logger.info("{self._collection_name.lower()} already exists.")
+                logger.info(f"{self._collection_name.lower()} already exists.")
                 redis_client.set(collection_exist_cache_key, 1, ex=3600)
                 return
+            if len(self.kwargs) == 0 and len(kwargs) != 0:
+                self.kwargs = copy.deepcopy(kwargs)
             vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-            shards = kwargs.pop("shards", 2)
+            shards = kwargs.pop("shards", 4)
 
             engine = kwargs.pop("engine", "lvector")
-            method_name = kwargs.pop("method_name", "hnsw")
+            method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
+            space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
             data_type = kwargs.pop("data_type", "float")
-            space_type = kwargs.pop("space_type", "cosinesimil")
 
             hnsw_m = kwargs.pop("hnsw_m", 24)
             hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ -288,10 +261,10 @@ class LindormVectorStore(BaseVector):
             mapping = default_text_mapping(
                 dimension,
                 method_name,
-                space_type=space_type,
                 shards=shards,
                 engine=engine,
                 data_type=data_type,
+                space_type=space_type,
                 vector_field=vector_field,
                 hnsw_m=hnsw_m,
                 hnsw_ef_construction=hnsw_ef_construction,
@@ -301,6 +274,7 @@ class LindormVectorStore(BaseVector):
                 centroids_hnsw_m=centroids_hnsw_m,
                 centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
                 centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+                using_ugc=self._using_ugc,
                 **kwargs,
             )
             self._client.indices.create(index=self._collection_name.lower(), body=mapping)
@@ -309,15 +283,20 @@ class LindormVectorStore(BaseVector):
 
 
 def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
     routing_field = kwargs.get("routing_field")
     excludes_from_source = kwargs.get("excludes_from_source")
     analyzer = kwargs.get("analyzer", "ik_max_word")
     text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
     engine = kwargs["engine"]
     shard = kwargs["shards"]
-    space_type = kwargs["space_type"]
+    space_type = kwargs.get("space_type")
+    if space_type is None:
+        if method_name == "hnsw":
+            space_type = "l2"
+        else:
+            space_type = "cosine"
     data_type = kwargs["data_type"]
     vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+    using_ugc = kwargs.get("using_ugc", False)
 
     if method_name == "ivfpq":
         ivfpq_m = kwargs["ivfpq_m"]
@@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
     if excludes_from_source:
         mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}
 
-    if method_name == "ivfpq" and routing_field is not None:
+    if using_ugc and method_name == "ivfpq":
         mapping["settings"]["index"]["knn_routing"] = True
         mapping["settings"]["index"]["knn.offline.construction"] = True
-
-    if method_name == "flat" and routing_field is not None:
+    elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
         mapping["settings"]["index"]["knn_routing"] = True
 
     return mapping
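
Note: the elif above relies on `and` binding tighter than `or`, so it parses as "(using_ugc and hnsw) or (using_ugc and flat)". An equivalent, more explicit form of the branch (same behavior, restated for clarity):

    if using_ugc and method_name == "ivfpq":
        mapping["settings"]["index"]["knn_routing"] = True
        mapping["settings"]["index"]["knn.offline.construction"] = True
    elif using_ugc and method_name in ("hnsw", "flat"):
        mapping["settings"]["index"]["knn_routing"] = True
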
@@ -386,14 +363,12 @@ def default_text_search_query(
     minimum_should_match: int = 0,
     filters: Optional[list[dict]] = None,
     routing: Optional[str] = None,
+    routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
     if routing is not None:
-        routing_field = kwargs.get("routing_field", "routing_field")
         query_clause = {
-            "bool": {
-                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
-            }
+            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
         }
     else:
         query_clause = {"match": {text_field: query_text}}
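
Note: with routing supplied, the term filter now targets the top-level routing field's keyword sub-field instead of the old metadata.<routing_field> path. The clause produced for a routed query, assuming the default text field "page_content" (values illustrative):

    query_clause = {
        "bool": {
            "must": [
                {"match": {"page_content": "what is lindorm"}},
                {"term": {"routing_field.keyword": "vector_index_abc123_node"}},
            ]
        }
    }
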
@@ -483,16 +458,40 @@ def default_vector_search_query(
 
 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
         lindorm_config = LindormVectorStoreConfig(
             hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
+            using_ugc=dify_config.USING_UGC_INDEX,
         )
-        return LindormVectorStore(collection_name, lindorm_config)
+        using_ugc = dify_config.USING_UGC_INDEX
+        routing_value = None
+        if dataset.index_struct:
+            if using_ugc:
+                dimension = dataset.index_struct_dict["dimension"]
+                index_type = dataset.index_struct_dict["index_type"]
+                distance_type = dataset.index_struct_dict["distance_type"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            else:
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+        else:
+            embedding_vector = embeddings.embed_query("hello word")
+            dimension = len(embedding_vector)
+            index_type = dify_config.DEFAULT_INDEX_TYPE
+            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
+            class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct_dict = {
+                "type": VectorType.LINDORM,
+                "vector_store": {"class_prefix": class_prefix},
+                "index_type": index_type,
+                "dimension": dimension,
+                "distance_type": distance_type,
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+            if using_ugc:
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = class_prefix
+            else:
+                index_name = class_prefix
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
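
Note: under UGC mode, all datasets sharing a dimension, index type, and distance type land in one physical index, and the per-dataset class_prefix becomes the routing value. Naming sketch (values assumed):

    dimension, index_type, distance_type = 1024, "hnsw", "l2"
    index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"  # -> "ugc_index_1024_hnsw_l2"
    routing_value = class_prefix  # e.g. Dataset.gen_collection_name_by_id(dataset.id)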