|
|
|
@@ -25,6 +25,7 @@ class OpenGaussConfig(BaseModel): |
|
|
|
database: str |
|
|
|
min_connection: int |
|
|
|
max_connection: int |
|
|
|
enable_pq: bool = False # Enable PQ acceleration |
|
|
|
|
|
|
|
@model_validator(mode="before") |
|
|
|
@classmethod |
|
|
|
@@ -57,6 +58,11 @@ CREATE TABLE IF NOT EXISTS {table_name} ( |
|
|
|
); |
|
|
|
""" |
|
|
|
|
|
|
|
SQL_CREATE_INDEX_PQ = """ |
|
|
|
CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name} |
|
|
|
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m}); |
|
|
|
""" |
|
|
|
|
|
|
|
SQL_CREATE_INDEX = """ |
|
|
|
CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name} |
|
|
|
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64); |
|
|
|
@@ -68,6 +74,7 @@ class OpenGauss(BaseVector): |
|
|
|
super().__init__(collection_name) |
|
|
|
self.pool = self._create_connection_pool(config) |
|
|
|
self.table_name = f"embedding_{collection_name}" |
|
|
|
self.pq_enabled = config.enable_pq |
|
|
|
|
|
|
|
def get_type(self) -> str: |
|
|
|
return VectorType.OPENGAUSS |
|
|
|
@@ -97,7 +104,26 @@ class OpenGauss(BaseVector): |
|
|
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): |
|
|
|
dimension = len(embeddings[0]) |
|
|
|
self._create_collection(dimension) |
|
|
|
return self.add_texts(texts, embeddings) |
|
|
|
self.add_texts(texts, embeddings) |
|
|
|
self._create_index(dimension) |
|
|
|
|
|
|
|
def _create_index(self, dimension: int): |
|
|
|
index_cache_key = f"vector_index_{self._collection_name}" |
|
|
|
lock_name = f"{index_cache_key}_lock" |
|
|
|
with redis_client.lock(lock_name, timeout=60): |
|
|
|
index_exist_cache_key = f"vector_index_{self._collection_name}" |
|
|
|
if redis_client.get(index_exist_cache_key): |
|
|
|
return |
|
|
|
|
|
|
|
with self._get_cursor() as cur: |
|
|
|
if dimension <= 2000: |
|
|
|
if self.pq_enabled: |
|
|
|
cur.execute(SQL_CREATE_INDEX_PQ.format(table_name=self.table_name, pq_m=int(dimension / 4))) |
|
|
|
cur.execute("SET hnsw_earlystop_threshold = 320") |
|
|
|
|
|
|
|
if not self.pq_enabled: |
|
|
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) |
|
|
|
redis_client.set(index_exist_cache_key, 1, ex=3600) |
|
|
|
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): |
|
|
|
values = [] |
|
|
|
@@ -211,8 +237,6 @@ class OpenGauss(BaseVector): |
|
|
|
|
|
|
|
with self._get_cursor() as cur: |
|
|
|
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension)) |
|
|
|
if dimension <= 2000: |
|
|
|
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) |
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600) |
|
|
|
|
|
|
|
|
|
|
|
@@ -236,5 +260,6 @@ class OpenGaussFactory(AbstractVectorFactory): |
|
|
|
database=dify_config.OPENGAUSS_DATABASE or "dify", |
|
|
|
min_connection=dify_config.OPENGAUSS_MIN_CONNECTION, |
|
|
|
max_connection=dify_config.OPENGAUSS_MAX_CONNECTION, |
|
|
|
enable_pq=dify_config.OPENGAUSS_ENABLE_PQ or False, |
|
|
|
), |
|
|
|
) |