
Lindorm vdb (#11574)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
commit 0d04cdc323 · Jiang · 10 months ago · tag 0.14.0

api/.env.example (+1, -0)

 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
 LINDORM_USERNAME=admin
 LINDORM_PASSWORD=admin
+USING_UGC_INDEX=False

 # OceanBase Vector configuration
 OCEANBASE_VECTOR_HOST=127.0.0.1
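For context, a minimal sketch of how the new toggle could be read from the environment. Dify's real config layer is pydantic-based; the plain `os.environ` reads below are illustrative only, with defaults mirroring the `.env.example` values:

```python
import os

# Illustrative env reads; names match the keys added above.
LINDORM_URL = os.environ.get("LINDORM_URL", "http://localhost:30070")
LINDORM_USERNAME = os.environ.get("LINDORM_USERNAME", "admin")
LINDORM_PASSWORD = os.environ.get("LINDORM_PASSWORD", "admin")
# Env vars are strings, so the boolean toggle needs an explicit parse.
USING_UGC_INDEX = os.environ.get("USING_UGC_INDEX", "False").lower() == "true"
```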

api/configs/middleware/vdb/lindorm_config.py (+11, -0)

description="Lindorm password", description="Lindorm password",
default=None, default=None,
) )
DEFAULT_INDEX_TYPE: Optional[str] = Field(
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
default="hnsw",
)
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
)
USING_UGC_INDEX: Optional[bool] = Field(
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)

api/core/rag/datasource/vdb/lindorm/lindorm_vector.py (+119, -120)

 import copy
 import json
 import logging
+from collections.abc import Iterable
 from typing import Any, Optional

 from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
 from pydantic import BaseModel, model_validator
+from tenacity import retry, stop_after_attempt, wait_fixed

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ ... @@
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger("lindorm").setLevel(logging.WARN)

+ROUTING_FIELD = "routing_field"
+UGC_INDEX_PREFIX = "ugc_index"


 class LindormVectorStoreConfig(BaseModel):
     hosts: str
     username: Optional[str] = None
     password: Optional[str] = None
+    using_ugc: Optional[bool] = False

     @model_validator(mode="before")
     @classmethod
@@ ... @@
         return values

     def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts,
-        }
+        params = {"hosts": self.hosts}
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
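As a usage note, `to_opensearch_params` produces exactly the kwargs handed to the `OpenSearch` client constructor. A sketch with placeholder endpoint and credentials:

```python
from opensearchpy import OpenSearch

from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreConfig

# Placeholder endpoint/credentials.
config = LindormVectorStoreConfig(
    hosts="http://ld-xxx-proxy-search-pub.lindorm.aliyuncs.com:30070",
    username="admin",
    password="admin",
    using_ugc=True,
)
params = config.to_opensearch_params()
# -> {"hosts": "...", "http_auth": ("admin", "admin")}
client = OpenSearch(**params)
```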


 class LindormVectorStore(BaseVector):
     def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
-        super().__init__(collection_name.lower())
+        self._routing = None
+        self._routing_field = None
+        if config.using_ugc:
+            routing_value: str = kwargs.get("routing_value")
+            if routing_value is None:
+                raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
+            self._routing = routing_value.lower()
+            self._routing_field = ROUTING_FIELD
+            ugc_index_name = collection_name
+            super().__init__(ugc_index_name.lower())
+        else:
+            super().__init__(collection_name.lower())
         self._client_config = config
         self._client = OpenSearch(**config.to_opensearch_params())
+        self._using_ugc = config.using_ugc
         self.kwargs = kwargs
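In UGC mode the store must be constructed with a `routing_value`, which becomes both the shard routing key and a filterable document field. A construction sketch; index and routing names here are placeholders:

```python
from core.rag.datasource.vdb.lindorm.lindorm_vector import (
    LindormVectorStore,
    LindormVectorStoreConfig,
)

# Placeholder endpoint/credentials, UGC mode on.
config = LindormVectorStoreConfig(
    hosts="http://localhost:30070", username="admin", password="admin", using_ugc=True
)
store = LindormVectorStore(
    collection_name="ugc_index_1024_hnsw_l2",   # shared physical index (placeholder)
    config=config,
    routing_value="vector_index_example",       # per-dataset logical key, lowercased internally
)
```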


     def get_type(self) -> str:
@@ ... @@
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)


+    def __filter_existed_ids(
+        self,
+        texts: list[str],
+        metadatas: list[dict],
+        ids: list[str],
+        bulk_size: int = 1024,
+    ) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
+        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
+        def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
+            try:
+                existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
+                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
+            except Exception as e:
+                logger.exception(f"Error fetching batch {batch_ids}")
+                return set()
+
+        @retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
+        def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
+            try:
+                existing_docs = self._client.mget(
+                    body={
+                        "docs": [
+                            {"_index": self._collection_name, "_id": id, "routing": routing}
+                            for id, routing in zip(batch_ids, route_ids)
+                        ]
+                    },
+                    _source=False,
+                )
+                return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
+            except Exception as e:
+                logger.exception(f"Error fetching batch ids: {batch_ids}")
+                return set()
+
+        if ids is None:
+            return texts, metadatas, ids
+
+        if len(texts) != len(ids):
+            raise RuntimeError(f"texts {len(texts)} != {ids}")
+
+        filtered_texts = []
+        filtered_metadatas = []
+        filtered_ids = []
+
+        def batch(iterable, n):
+            length = len(iterable)
+            for idx in range(0, length, n):
+                yield iterable[idx : min(idx + n, length)]
+
+        for ids_batch, texts_batch, metadatas_batch in zip(
+            batch(ids, bulk_size),
+            batch(texts, bulk_size),
+            batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
+        ):
+            existing_ids_set = __fetch_existing_ids(ids_batch)
+            for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
+                if doc_id not in existing_ids_set:
+                    filtered_texts.append(text)
+                    filtered_ids.append(doc_id)
+                    if metadatas is not None:
+                        filtered_metadatas.append(metadata)
+
+        return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
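The dedup logic slices ids/texts/metadatas into aligned chunks and issues one `mget` per chunk. A standalone sketch of the batching and the multi-get body it produces:

```python
# Standalone copy of the batch slicing used above (not the class method itself).
def batch(iterable, n):
    length = len(iterable)
    for idx in range(0, length, n):
        yield iterable[idx : min(idx + n, length)]

ids = [f"doc-{i}" for i in range(5)]
assert list(batch(ids, 2)) == [["doc-0", "doc-1"], ["doc-2", "doc-3"], ["doc-4"]]

# One mget body per chunk; docs already present are then skipped by "_id".
mget_body = {"ids": ["doc-0", "doc-1"]}
```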

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         actions = []
         uuids = self._get_uuids(documents)
         for i in range(len(documents)):
-            action = {
-                "_op_type": "index",
-                "_index": self._collection_name.lower(),
-                "_id": uuids[i],
-                "_source": {
-                    Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                    Field.METADATA_KEY.value: documents[i].metadata,
-                },
-            }
-            actions.append(action)
-        bulk(self._client, actions)
-        self.refresh()
+            action_header = {
+                "index": {
+                    "_index": self.collection_name.lower(),
+                    "_id": uuids[i],
+                }
+            }
+            action_values = {
+                Field.CONTENT_KEY.value: documents[i].page_content,
+                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
+                Field.METADATA_KEY.value: documents[i].metadata,
+            }
+            if self._using_ugc:
+                action_header["index"]["routing"] = self._routing
+                action_values[self._routing_field] = self._routing
+            actions.append(action_header)
+            actions.append(action_values)
+        response = self._client.bulk(actions)
+        if response["errors"]:
+            for item in response["items"]:
+                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+        else:
+            self.refresh()
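The rewrite drops the `opensearchpy.helpers.bulk` helper for the raw bulk API, where each document is an action-header entry followed by a source entry, and the per-document routing rides on the header. A minimal sketch with placeholder index, id, and routing values (field names follow the diff's `Field` enum):

```python
from opensearchpy import OpenSearch

client = OpenSearch(hosts="http://localhost:30070")  # placeholder endpoint
actions = [
    # Header: where the doc goes, plus routing when UGC is on.
    {"index": {"_index": "ugc_index_1024_hnsw_l2", "_id": "doc-1", "routing": "vector_index_example"}},
    # Source: the document body, with the routing value duplicated as a field.
    {
        "page_content": "hello lindorm",
        "vector": [0.1, 0.2, 0.3],
        "metadata": {"doc_id": "doc-1"},
        "routing_field": "vector_index_example",
    },
]
response = client.bulk(actions)
assert not response["errors"]
```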


     def get_ids_by_metadata_field(self, key: str, value: str):
-        query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
+        query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
+        if self._using_ugc:
+            query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
         response = self._client.search(index=self._collection_name, body=query)
         if response["hits"]["hits"]:
             return [hit["_id"] for hit in response["hits"]["hits"]]
         return None
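With UGC enabled, the extra term filter on the routing field keeps each dataset's documents separable inside the shared index. The search body ends up like this (values illustrative):

```python
# UGC-mode query body for key="doc_id", value="doc-1"; routing value is a placeholder.
query = {
    "query": {
        "bool": {
            "must": [
                {"term": {"metadata.doc_id.keyword": "doc-1"}},
                {"term": {"routing_field.keyword": "vector_index_example"}},
            ]
        }
    }
}
```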


     def delete_by_metadata_field(self, key: str, value: str):
-        query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
-        results = self._client.search(index=self._collection_name, body=query_str)
-        ids = [hit["_id"] for hit in results["hits"]["hits"]]
+        ids = self.get_ids_by_metadata_field(key, value)
         if ids:
             self.delete_by_ids(ids)

     def delete_by_ids(self, ids: list[str]) -> None:
+        params = {}
+        if self._using_ugc:
+            params["routing"] = self._routing
         for id in ids:
-            if self._client.exists(index=self._collection_name, id=id):
-                self._client.delete(index=self._collection_name, id=id)
+            if self._client.exists(index=self._collection_name, id=id, params=params):
+                self._client.delete(index=self._collection_name, id=id, params=params)
+                self.refresh()
             else:
                 logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")


     def delete(self) -> None:
+        try:
+            if self._using_ugc:
+                routing_filter_query = {
+                    "query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
+                }
+                self._client.delete_by_query(self._collection_name, body=routing_filter_query)
+                self.refresh()
+            else:
                 if self._client.indices.exists(index=self._collection_name):
                     self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
                     logger.info("Delete index success")
                 else:
                     logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
+        except Exception as e:
+            logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
+            raise e

     def text_exists(self, id: str) -> bool:
         try:
-            self._client.get(index=self._collection_name, id=id)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            self._client.get(index=self._collection_name, id=id, params=params)
             return True
         except:
             return False


     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        # Make sure query_vector is a list
         if not isinstance(query_vector, list):
             raise ValueError("query_vector should be a list of floats")

-        # Check whether query_vector is a floating-point number list
         if not all(isinstance(x, float) for x in query_vector):
             raise ValueError("All elements in query_vector should be floats")

         top_k = kwargs.get("top_k", 10)
         query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
         try:
-            response = self._client.search(index=self._collection_name, body=query)
+            params = {}
+            if self._using_ugc:
+                params["routing"] = self._routing
+            response = self._client.search(index=self._collection_name, body=query, params=params)
         except Exception as e:
             logger.exception(f"Error executing vector search, query: {query}")
             raise
@@ ... @@
         minimum_should_match = kwargs.get("minimum_should_match", 0)
         top_k = kwargs.get("top_k", 10)
         filters = kwargs.get("filter")
-        routing = kwargs.get("routing")
+        routing = self._routing
         full_text_query = default_text_search_query(
             query_text=query,
             k=top_k,
             minimum_should_match=minimum_should_match,
             filters=filters,
             routing=routing,
+            routing_field=self._routing_field,
         )
         response = self._client.search(index=self._collection_name, body=full_text_query)
         docs = []
@@ ... @@
             logger.info(f"Collection {self._collection_name} already exists.")
             return
         if self._client.indices.exists(index=self._collection_name):
-            logger.info("{self._collection_name.lower()} already exists.")
+            logger.info(f"{self._collection_name.lower()} already exists.")
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
             return
         if len(self.kwargs) == 0 and len(kwargs) != 0:
             self.kwargs = copy.deepcopy(kwargs)
         vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
-        shards = kwargs.pop("shards", 2)
+        shards = kwargs.pop("shards", 4)

         engine = kwargs.pop("engine", "lvector")
-        method_name = kwargs.pop("method_name", "hnsw")
+        method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
+        space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
         data_type = kwargs.pop("data_type", "float")
-        space_type = kwargs.pop("space_type", "cosinesimil")

         hnsw_m = kwargs.pop("hnsw_m", 24)
         hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ ... @@
         mapping = default_text_mapping(
             dimension,
             method_name,
+            space_type=space_type,
             shards=shards,
             engine=engine,
             data_type=data_type,
-            space_type=space_type,
             vector_field=vector_field,
             hnsw_m=hnsw_m,
             hnsw_ef_construction=hnsw_ef_construction,
             centroids_hnsw_m=centroids_hnsw_m,
             centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
             centroids_hnsw_ef_search=centroids_hnsw_ef_search,
+            using_ugc=self._using_ugc,
             **kwargs,
         )
         self._client.indices.create(index=self._collection_name.lower(), body=mapping)




 def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
-    routing_field = kwargs.get("routing_field")
     excludes_from_source = kwargs.get("excludes_from_source")
     analyzer = kwargs.get("analyzer", "ik_max_word")
     text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
     engine = kwargs["engine"]
     shard = kwargs["shards"]
-    space_type = kwargs["space_type"]
+    space_type = kwargs.get("space_type")
+    if space_type is None:
+        if method_name == "hnsw":
+            space_type = "l2"
+        else:
+            space_type = "cosine"
     data_type = kwargs["data_type"]
     vector_field = kwargs.get("vector_field", Field.VECTOR.value)
+    using_ugc = kwargs.get("using_ugc", False)
@@ ... @@
     if method_name == "ivfpq":
         ivfpq_m = kwargs["ivfpq_m"]
@@ ... @@
     if excludes_from_source:
         mapping["mappings"]["_source"] = {"excludes": excludes_from_source}  # e.g. {"excludes": ["vector_field"]}

-    if method_name == "ivfpq" and routing_field is not None:
+    if using_ugc and method_name == "ivfpq":
         mapping["settings"]["index"]["knn_routing"] = True
         mapping["settings"]["index"]["knn.offline.construction"] = True
-
-    if method_name == "flat" and routing_field is not None:
+    elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
         mapping["settings"]["index"]["knn_routing"] = True

     return mapping
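Under UGC the mapping turns on Lindorm's knn routing so that all vectors sharing a routing value land on the same shard. A rough sketch of the resulting settings toggles; the two keys come from the diff above, everything else about the full mapping is omitted here:

```python
# Settings toggles per method under UGC (sketch only, not the full mapping body).
settings_hnsw_or_flat = {"index": {"knn_routing": True}}
settings_ivfpq = {"index": {"knn_routing": True, "knn.offline.construction": True}}
```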




 def default_text_search_query(
@@ ... @@
     minimum_should_match: int = 0,
     filters: Optional[list[dict]] = None,
     routing: Optional[str] = None,
+    routing_field: Optional[str] = None,
     **kwargs,
 ) -> dict:
     if routing is not None:
-        routing_field = kwargs.get("routing_field", "routing_field")
         query_clause = {
-            "bool": {
-                "must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
-            }
+            "bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
         }
     else:
         query_clause = {"match": {text_field: query_text}}


 class LindormVectorStoreFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
-        if dataset.index_struct_dict:
-            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
-            collection_name = class_prefix
-        else:
-            dataset_id = dataset.id
-            collection_name = Dataset.gen_collection_name_by_id(dataset_id)
-            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
         lindorm_config = LindormVectorStoreConfig(
             hosts=dify_config.LINDORM_URL,
             username=dify_config.LINDORM_USERNAME,
             password=dify_config.LINDORM_PASSWORD,
+            using_ugc=dify_config.USING_UGC_INDEX,
         )
-        return LindormVectorStore(collection_name, lindorm_config)
+        using_ugc = dify_config.USING_UGC_INDEX
+        routing_value = None
+        if dataset.index_struct:
+            if using_ugc:
+                dimension = dataset.index_struct_dict["dimension"]
+                index_type = dataset.index_struct_dict["index_type"]
+                distance_type = dataset.index_struct_dict["distance_type"]
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            else:
+                index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
+        else:
+            embedding_vector = embeddings.embed_query("hello word")
+            dimension = len(embedding_vector)
+            index_type = dify_config.DEFAULT_INDEX_TYPE
+            distance_type = dify_config.DEFAULT_DISTANCE_TYPE
+            class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
+            index_struct_dict = {
+                "type": VectorType.LINDORM,
+                "vector_store": {"class_prefix": class_prefix},
+                "index_type": index_type,
+                "dimension": dimension,
+                "distance_type": distance_type,
+            }
+            dataset.index_struct = json.dumps(index_struct_dict)
+            if using_ugc:
+                index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
+                routing_value = class_prefix
+            else:
+                index_name = class_prefix
+        return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
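So with UGC on, every dataset sharing an embedding dimension, index type, and distance type maps to one physical index, and the dataset's `class_prefix` becomes its routing value. The name derivation, with illustrative values:

```python
# Illustrative derivation of the shared UGC index name.
UGC_INDEX_PREFIX = "ugc_index"
dimension, index_type, distance_type = 1024, "hnsw", "l2"
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
print(index_name)  # -> ugc_index_1024_hnsw_l2
```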

api/tests/integration_tests/vdb/lindorm/test_lindorm.py (+26, -3)





 class Config:
-    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
+    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
     SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
-    SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
+    SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
+    USING_UGC = env.bool("USING_UGC", True)




 class TestLindormVectorStore(AbstractVectorTest):
@@ ... @@
         assert ids[0] == self.example_doc_id




+class TestLindormVectorStoreUGC(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = LindormVectorStore(
+            collection_name="ugc_index_test",
+            config=LindormVectorStoreConfig(
+                hosts=Config.SEARCH_ENDPOINT,
+                username=Config.SEARCH_USERNAME,
+                password=Config.SEARCH_PWD,
+                using_ugc=Config.USING_UGC,
+            ),
+            routing_value=self.collection_name,
+        )
+
+    def get_ids_by_metadata_field(self):
+        ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
+        assert ids is not None
+        assert len(ids) == 1
+        assert ids[0] == self.example_doc_id
+
+
-def test_lindorm_vector(setup_mock_redis):
+def test_lindorm_vector_ugc(setup_mock_redis):
     TestLindormVectorStore().run_all_tests()
+    TestLindormVectorStoreUGC().run_all_tests()
