|
|
|
|
|
|
import json |
|
|
|
import logging |
|
|
|
from typing import Any |
|
|
|
|
|
|
|
import sqlalchemy |
|
|
|
from pydantic import BaseModel, root_validator |
|
|
|
from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert |
|
|
|
from sqlalchemy import text as sql_text |
|
|
|
from sqlalchemy.orm import Session, declarative_base |
|
|
|
|
|
|
|
from core.rag.datasource.vdb.vector_base import BaseVector |
|
|
|
from core.rag.models.document import Document |
|
|
|
from extensions.ext_redis import redis_client |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class TiDBVectorConfig(BaseModel):
    """Connection settings for a TiDB vector database."""

    host: str
    port: int
    user: str
    password: str
    database: str

    @root_validator()
    def validate_config(cls, values: dict) -> dict:
        """Reject a configuration in which any required field is empty.

        The field-to-env-var mapping below preserves the original check
        order and error message wording exactly.
        """
        required_fields = {
            'host': 'TIDB_VECTOR_HOST',
            'port': 'TIDB_VECTOR_PORT',
            'user': 'TIDB_VECTOR_USER',
            'password': 'TIDB_VECTOR_PASSWORD',
            'database': 'TIDB_VECTOR_DATABASE',
        }
        for field_name, env_var in required_fields.items():
            if not values[field_name]:
                raise ValueError(f"config {env_var} is required")
        return values
|
|
|
|
|
|
|
|
|
|
|
class TiDBVector(BaseVector):
    """Vector store backed by a TiDB table with a ``VECTOR<FLOAT>`` column.

    Each row holds one document chunk: a UUID primary key, the chunk text,
    a JSON metadata blob and its embedding. Similarity search is delegated
    to TiDB's built-in vector distance SQL functions.
    """

    def _table(self, dim: int) -> Table:
        """Return the SQLAlchemy ``Table`` mapped onto the collection table.

        :param dim: dimensionality of the ``vector`` column.
        """
        from tidb_vector.sqlalchemy import VectorType
        return Table(
            self._collection_name,
            self._orm_base.metadata,
            Column('id', String(36), primary_key=True, nullable=False),
            # The column comment tells TiDB which distance function the HNSW
            # index should use.
            Column("vector", VectorType(dim), nullable=False,
                   comment="" if self._distance_func is None else f"hnsw(distance={self._distance_func})"),
            Column("text", TEXT, nullable=False),
            Column("meta", JSON, nullable=False),
            Column("create_time", DateTime, server_default=sqlalchemy.text("CURRENT_TIMESTAMP")),
            Column("update_time", DateTime,
                   server_default=sqlalchemy.text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP")),
            # The same table may be built more than once against one metadata.
            extend_existing=True
        )

    def __init__(self, collection_name: str, config: TiDBVectorConfig, distance_func: str = 'cosine'):
        """
        :param collection_name: name of the backing TiDB table.
        :param config: TiDB connection settings.
        :param distance_func: ``'cosine'`` (default) or ``'l2'``.
        """
        super().__init__(collection_name)
        self._client_config = config
        self._url = (f"mysql+pymysql://{config.user}:{config.password}@{config.host}:{config.port}/{config.database}?"
                     f"ssl_verify_cert=true&ssl_verify_identity=true")
        self._distance_func = distance_func.lower()
        self._engine = create_engine(self._url)
        self._orm_base = declarative_base()
        # Fallback embedding size; overwritten by create() once real
        # embeddings are seen.
        self._dimension = 1536

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Create the collection table, then insert the initial documents.

        The vector dimension is taken from the first embedding.
        """
        logger.info("create collection and add texts, collection_name: " + self._collection_name)
        self._create_collection(len(embeddings[0]))
        self.add_texts(texts, embeddings)
        self._dimension = len(embeddings[0])

    def _create_collection(self, dimension: int):
        """Drop and re-create the backing table.

        A redis lock serialises concurrent workers, and a short-lived cache
        key makes repeated calls for the same collection a no-op.

        :param dimension: dimensionality of the ``vector`` column.
        """
        logger.info("_create_collection, collection_name " + self._collection_name)
        lock_name = 'vector_indexing_lock_{}'.format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
            if redis_client.get(collection_exist_cache_key):
                return
            with Session(self._engine) as session:
                session.begin()
                drop_statement = sql_text(f"""DROP TABLE IF EXISTS {self._collection_name}; """)
                session.execute(drop_statement)
                create_statement = sql_text(f"""
                    CREATE TABLE IF NOT EXISTS {self._collection_name} (
                        id CHAR(36) PRIMARY KEY,
                        text TEXT NOT NULL,
                        meta JSON NOT NULL,
                        vector VECTOR<FLOAT>({dimension}) NOT NULL COMMENT "hnsw(distance={self._distance_func})",
                        create_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                        update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
                    );
                """)
                session.execute(create_statement)
                # tidb vector does not support 'CREATE/ADD INDEX' yet
                session.commit()
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Insert documents and their embeddings in batches of 500.

        :return: the generated row ids.
        """
        table = self._table(len(embeddings[0]))
        ids = self._get_uuids(documents)
        metas = [d.metadata for d in documents]
        texts = [d.page_content for d in documents]

        chunks_table_data = []
        with self._engine.connect() as conn:
            with conn.begin():
                for id, text, meta, embedding in zip(
                        ids, texts, metas, embeddings
                ):
                    chunks_table_data.append({"id": id, "vector": embedding, "text": text, "meta": meta})

                    # Execute the batch insert when the batch size is reached
                    if len(chunks_table_data) == 500:
                        conn.execute(insert(table).values(chunks_table_data))
                        # Clear the chunks_table_data list for the next batch
                        chunks_table_data.clear()

                # Insert any remaining records that didn't make up a full batch
                if chunks_table_data:
                    conn.execute(insert(table).values(chunks_table_data))
        return ids

    def text_exists(self, id: str) -> bool:
        """Return True when a chunk whose ``doc_id`` metadata equals *id* exists."""
        result = self.get_ids_by_metadata_field('doc_id', id)
        # Safe now that get_ids_by_metadata_field returns [] (not None) on
        # no match; the original raised TypeError on len(None) here.
        return len(result) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete all chunks whose ``doc_id`` metadata is in *ids*."""
        with Session(self._engine) as session:
            # ids are internally generated UUID strings, so interpolating
            # them into the IN (...) list is safe here.
            ids_str = ','.join(f"'{doc_id}'" for doc_id in ids)
            select_statement = sql_text(
                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.doc_id' in ({ids_str}); """
            )
            result = session.execute(select_statement).fetchall()
        if result:
            ids = [item[0] for item in result]
            self._delete_by_ids(ids)

    def _delete_by_ids(self, ids: list[str]) -> bool:
        """Delete rows by primary key.

        :param ids: primary-key values to delete; must not be None.
        :return: True on success, False when the delete raised.
        """
        if ids is None:
            raise ValueError("No ids provided to delete.")
        table = self._table(self._dimension)
        try:
            with self._engine.connect() as conn:
                with conn.begin():
                    delete_condition = table.c.id.in_(ids)
                    conn.execute(table.delete().where(delete_condition))
                    return True
        except Exception:
            # Log through the module logger (the original print() silently
            # bypassed the application's logging setup).
            logger.exception("Delete operation failed")
            return False

    def delete_by_document_id(self, document_id: str):
        """Delete every chunk belonging to the given document."""
        ids = self.get_ids_by_metadata_field('document_id', document_id)
        if ids:
            self._delete_by_ids(ids)

    def get_ids_by_metadata_field(self, key: str, value: str):
        """Return the row ids whose JSON metadata field *key* equals *value*.

        :return: list of ids; an empty list when there is no match (the
            previous ``None`` return crashed ``text_exists``).
        """
        with Session(self._engine) as session:
            # *key* is an internal field name; *value* may originate from
            # user input, so it is passed as a bound parameter instead of
            # being interpolated into the SQL string.
            select_statement = sql_text(
                f"""SELECT id FROM {self._collection_name} WHERE meta->>'$.{key}' = :value; """
            )
            result = session.execute(select_statement, {'value': value}).fetchall()
        return [item[0] for item in result] if result else []

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete every chunk whose metadata field *key* equals *value*."""
        ids = self.get_ids_by_metadata_field(key, value)
        if ids:
            self._delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Return up to ``top_k`` documents ordered by vector distance.

        :param query_vector: the query embedding.
        :keyword top_k: maximum number of results (default 5).
        :keyword score_threshold: minimum similarity score; converted to a
            maximum distance via ``distance = 1 - score_threshold``.
        """
        top_k = kwargs.get("top_k", 5)
        score_threshold = kwargs.get("score_threshold") or 0.0
        distance = 1 - score_threshold

        query_vector_str = ", ".join(format(x) for x in query_vector)
        query_vector_str = "[" + query_vector_str + "]"
        logger.debug(f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}")

        # Fix: the original tested 'l2' twice, so the cosine branch was
        # unreachable dead code. 'cosine' (and any unknown value) now falls
        # through to cosine distance, matching the index comment written by
        # _create_collection.
        if self._distance_func == 'l2':
            tidb_func = 'Vec_l2_distance'
        else:
            tidb_func = 'Vec_Cosine_distance'

        docs = []
        with Session(self._engine) as session:
            select_statement = sql_text(
                f"""SELECT meta, text FROM (
                    SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
                    FROM {self._collection_name}
                    ORDER BY distance
                    LIMIT {top_k}
                ) t WHERE distance < {distance};"""
            )
            res = session.execute(select_statement)
            results = [(row[0], row[1]) for row in res]
            for meta, text in results:
                docs.append(Document(page_content=text, metadata=json.loads(meta)))
        return docs

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search is unsupported: tidb doesn't support bm25 search,
        so this always returns an empty list."""
        return []

    def delete(self) -> None:
        """Drop the whole collection table."""
        with Session(self._engine) as session:
            session.execute(sql_text(f"""DROP TABLE IF EXISTS {self._collection_name};"""))
            session.commit()