|
|
|
@@ -0,0 +1,169 @@ |
|
|
|
import logging |
|
|
|
from typing import Any |
|
|
|
|
|
|
|
from pgvecto_rs.sdk import PGVectoRs, Record |
|
|
|
from pydantic import BaseModel, root_validator |
|
|
|
from sqlalchemy import text as sql_text |
|
|
|
from sqlalchemy.orm import Session |
|
|
|
|
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector |
|
|
|
from core.rag.models.document import Document |
|
|
|
from extensions.ext_redis import redis_client |
|
|
|
|
|
|
|
# Module-level logger, named after this module so log output can be filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
class RelytConfig(BaseModel):
    """Connection settings for a Relyt database.

    Every field is required; ``validate_config`` rejects missing or falsy
    values (empty string, ``0``, ``None``).
    """

    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Ensure every connection setting is present and non-empty.

        Args:
            values: Mapping of field name to validated value.

        Returns:
            The unchanged ``values`` mapping when all settings are present.

        Raises:
            ValueError: Naming the first missing setting, checked in the order
                host, port, user, password, database.
        """
        required_settings = (
            ('host', 'RELYT_HOST'),
            ('port', 'RELYT_PORT'),
            ('user', 'RELYT_USER'),
            ('password', 'RELYT_PASSWORD'),
            ('database', 'RELYT_DATABASE'),
        )
        for field_name, setting_name in required_settings:
            # Use .get(): a field that failed its own validation is absent from
            # ``values``, and indexing would raise a bare KeyError instead of
            # the intended validation error.
            if not values.get(field_name):
                raise ValueError(f"config {setting_name} is required")
        return values
|
|
|
|
|
|
|
|
|
|
|
class RelytVector(BaseVector):
    """Vector store that keeps documents in a Relyt database.

    Inserts, searches and id-based deletes go through the ``pgvecto_rs`` SDK
    client; table creation and metadata lookups are issued as raw SQL over the
    client's SQLAlchemy engine.
    """

    def __init__(self, collection_name: str, config: RelytConfig, dim: int):
        """Build the connection URL and the pgvecto_rs client.

        Args:
            collection_name: Name of the collection; passed to the base class
                and to the SDK client.
            config: Connection settings.
            dim: Dimension of the stored embeddings.
        """
        from urllib.parse import quote_plus

        super().__init__(collection_name)
        self._client_config = config
        # Percent-encode the credentials so characters such as '@', ':' or '/'
        # in the user name or password cannot corrupt the URL structure.
        self._url = (
            f"postgresql+psycopg2://{quote_plus(config.user)}:{quote_plus(config.password)}"
            f"@{config.host}:{config.port}/{config.database}"
        )
        self._client = PGVectoRs(
            db_url=self._url,
            collection_name=self._collection_name,
            dimension=dim
        )
        self._fields = []

    @property
    def _table_name(self) -> str:
        """Name of the SQL table used for this collection."""
        # Identifiers cannot be sent as bound parameters, so the collection
        # name is interpolated; it is assumed to be generated internally.
        return f"collection_{self._collection_name}"

    def get_type(self) -> str:
        """Return the identifier of this vector store type."""
        return 'relyt'

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection and store the given documents.

        Args:
            texts: Documents to store.
            embeddings: One embedding per document; the length of the first
                embedding is used as the vector dimension.
        """
        self.create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def create_collection(self, dimension: int):
        """Create the collection table and its vector index.

        The work is done under a Redis lock. A Redis key records that the
        collection was created; while that key exists the call is a no-op.

        Args:
            dimension: Dimension of the ``embedding`` column.
        """
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return

            index_name = f"{self._collection_name}_embedding_index"
            with Session(self._client._engine) as session:
                # NOTE(review): whenever the cache key is absent (for example
                # after its 3600 s expiry) an existing table is dropped and all
                # stored rows are lost - confirm this is intended.
                drop_statement = sql_text(f"DROP TABLE IF EXISTS {self._table_name}")
                session.execute(drop_statement)

                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS {self._table_name} (
                        id UUID PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSONB NOT NULL,
                        embedding vector({dimension}) NOT NULL
                    ) using heap;
                """)
                session.execute(create_statement)

                index_statement = sql_text(f"""
                    CREATE INDEX {index_name}
                    ON {self._table_name} USING vectors(embedding vector_l2_ops)
                    WITH (options = $$
                    optimizing.optimizing_threads = 30
                    segment.max_growing_segment_size = 2000
                    segment.max_sealed_segment_size = 30000000
                    [indexing.hnsw]
                    m=30
                    ef_construction=500
                    $$);
                """)
                session.execute(index_statement)
                session.commit()

            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Insert documents with their embeddings.

        Args:
            documents: Documents whose ``page_content`` and ``metadata`` are stored.
            embeddings: Embeddings paired positionally with ``documents``.

        Returns:
            The ids of the inserted records, as strings.
        """
        records = [
            Record.from_text(document.page_content, embedding, document.metadata)
            for document, embedding in zip(documents, embeddings)
        ]
        pks = [str(record.id) for record in records]
        self._client.insert(records)
        return pks

    def delete_by_document_id(self, document_id: str):
        """Delete every record whose ``document_id`` metadata equals ``document_id``."""
        self.delete_by_metadata_field('document_id', document_id)

    def get_ids_by_metadata_field(self, key: str, value: str):
        """Look up record ids by a metadata field.

        Args:
            key: Top-level key inside the ``meta`` JSONB column.
            value: Text value the key must equal.

        Returns:
            A list of matching record ids, or ``None`` when nothing matches.
        """
        # Key and value are sent as bound parameters, never interpolated.
        select_statement = sql_text(
            f"SELECT id FROM {self._table_name} WHERE meta ->> :key = :value"
        )
        with Session(self._client._engine) as session:
            rows = session.execute(select_statement, {'key': key, 'value': value}).fetchall()
        if not rows:
            return None
        return [row[0] for row in rows]

    def delete_by_metadata_field(self, key: str, value: str):
        """Delete every record whose metadata field ``key`` equals ``value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._client.delete_by_ids(ids)

    def delete_by_ids(self, doc_ids: list[str]) -> None:
        """Delete the records whose ``doc_id`` metadata is in ``doc_ids``.

        Args:
            doc_ids: Values of the ``doc_id`` metadata field to remove. An
                empty list is a no-op.
        """
        from sqlalchemy import bindparam

        if not doc_ids:
            return

        # An expanding parameter renders one placeholder per list element, so
        # the values are bound individually instead of being formatted into SQL.
        select_statement = sql_text(
            f"SELECT id FROM {self._table_name} WHERE meta ->> 'doc_id' IN :doc_ids"
        ).bindparams(bindparam('doc_ids', expanding=True))
        with Session(self._client._engine) as session:
            rows = session.execute(select_statement, {'doc_ids': list(doc_ids)}).fetchall()
        if rows:
            self._client.delete_by_ids([row[0] for row in rows])

    def delete(self) -> None:
        """Drop the table backing this collection, if it exists."""
        with Session(self._client._engine) as session:
            session.execute(sql_text(f"DROP TABLE IF EXISTS {self._table_name}"))
            session.commit()

    def text_exists(self, id: str) -> bool:
        """Return whether a record with ``doc_id`` metadata equal to ``id`` exists."""
        select_statement = sql_text(
            f"SELECT id FROM {self._table_name} WHERE meta ->> 'doc_id' = :doc_id LIMIT 1"
        )
        with Session(self._client._engine) as session:
            rows = session.execute(select_statement, {'doc_id': id}).fetchall()
        return len(rows) > 0

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Search the collection by embedding.

        Args:
            query_vector: Embedding to search with.
            **kwargs: ``top_k`` (required, converted with ``int``), ``filter``
                (passed to ``filters.meta_contains``) and ``score_threshold``
                (falsy values are treated as ``0.0``).

        Returns:
            Documents whose returned value is strictly greater than the
            threshold; that value is also stored under ``metadata['score']``.
        """
        from pgvecto_rs.sdk import filters

        filter_condition = filters.meta_contains(kwargs.get('filter'))
        results = self._client.search(
            top_k=int(kwargs.get('top_k')),
            embedding=query_vector,
            filter=filter_condition
        )

        # The threshold does not depend on the record, so resolve it once.
        score_threshold = kwargs.get('score_threshold') or 0.0

        # Organize results.
        docs = []
        for record, dis in results:
            metadata = record.meta
            metadata['score'] = dis
            # NOTE(review): ``dis`` looks like a distance (the index uses
            # vector_l2_ops), yet larger values pass the threshold - confirm
            # the intended comparison direction.
            if dis > score_threshold:
                docs.append(Document(page_content=record.text, metadata=metadata))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Return an empty list; full-text search is not implemented here."""
        # milvus/zilliz/relyt doesn't support bm25 search
        return []