|
|
|
@@ -1,7 +1,9 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import queue |
|
|
|
import re |
|
|
|
import threading |
|
|
|
import time |
|
|
|
import uuid |
|
|
|
from typing import TYPE_CHECKING, Any, Optional |
|
|
|
|
|
|
|
@@ -67,6 +69,243 @@ class ClickzettaConfig(BaseModel): |
|
|
|
return values |
|
|
|
|
|
|
|
|
|
|
|
class ClickzettaConnectionPool: |
|
|
|
""" |
|
|
|
Global connection pool for ClickZetta connections. |
|
|
|
Manages connection reuse across ClickzettaVector instances. |
|
|
|
""" |
|
|
|
|
|
|
|
_instance: Optional["ClickzettaConnectionPool"] = None |
|
|
|
_lock = threading.Lock() |
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
        self._pools: dict[str, list[tuple["Connection", float]]] = {}  # config_key -> [(connection, last_used_time)]
|
|
|
self._pool_locks: dict[str, threading.Lock] = {} |
|
|
|
self._max_pool_size = 5 # Maximum connections per configuration |
|
|
|
self._connection_timeout = 300 # 5 minutes timeout |
|
|
|
self._cleanup_thread: Optional[threading.Thread] = None |
|
|
|
self._shutdown = False |
|
|
|
self._start_cleanup_thread() |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def get_instance(cls) -> "ClickzettaConnectionPool": |
|
|
|
"""Get singleton instance of connection pool.""" |
|
|
|
if cls._instance is None: |
|
|
|
with cls._lock: |
|
|
|
if cls._instance is None: |
|
|
|
cls._instance = cls() |
|
|
|
return cls._instance |
|
|
|
|
|
|
|
def _get_config_key(self, config: ClickzettaConfig) -> str: |
|
|
|
"""Generate unique key for connection configuration.""" |
|
|
|
return ( |
|
|
|
f"{config.username}:{config.instance}:{config.service}:" |
|
|
|
f"{config.workspace}:{config.vcluster}:{config.schema_name}" |
|
|
|
) |
|
|
|
|
|
|
|
def _create_connection(self, config: ClickzettaConfig) -> "Connection": |
|
|
|
"""Create a new ClickZetta connection.""" |
|
|
|
max_retries = 3 |
|
|
|
retry_delay = 1.0 |
|
|
|
|
|
|
|
for attempt in range(max_retries): |
|
|
|
try: |
|
|
|
connection = clickzetta.connect( |
|
|
|
username=config.username, |
|
|
|
password=config.password, |
|
|
|
instance=config.instance, |
|
|
|
service=config.service, |
|
|
|
workspace=config.workspace, |
|
|
|
vcluster=config.vcluster, |
|
|
|
schema=config.schema_name, |
|
|
|
) |
|
|
|
|
|
|
|
# Configure connection session settings |
|
|
|
self._configure_connection(connection) |
|
|
|
logger.debug("Created new ClickZetta connection (attempt %d/%d)", attempt + 1, max_retries) |
|
|
|
return connection |
|
|
|
except Exception: |
|
|
|
logger.exception("ClickZetta connection attempt %d/%d failed", attempt + 1, max_retries) |
|
|
|
if attempt < max_retries - 1: |
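                    # Exponential backoff: with retry_delay = 1.0, this sleeps 1s, then 2s, between attempts.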
|
|
|
time.sleep(retry_delay * (2**attempt)) |
|
|
|
else: |
|
|
|
raise |
|
|
|
|
|
|
|
raise RuntimeError(f"Failed to create ClickZetta connection after {max_retries} attempts") |
|
|
|
|
|
|
|
def _configure_connection(self, connection: "Connection") -> None: |
|
|
|
"""Configure connection session settings.""" |
|
|
|
try: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
# Temporarily suppress ClickZetta client logging to reduce noise |
|
|
|
clickzetta_logger = logging.getLogger("clickzetta") |
|
|
|
original_level = clickzetta_logger.level |
|
|
|
clickzetta_logger.setLevel(logging.WARNING) |
|
|
|
|
|
|
|
try: |
|
|
|
# Use quote mode for string literal escaping |
|
|
|
cursor.execute("SET cz.sql.string.literal.escape.mode = 'quote'") |
|
|
|
|
|
|
|
# Apply performance optimization hints |
|
|
|
performance_hints = [ |
|
|
|
# Vector index optimization |
|
|
|
"SET cz.storage.parquet.vector.index.read.memory.cache = true", |
|
|
|
"SET cz.storage.parquet.vector.index.read.local.cache = false", |
|
|
|
# Query optimization |
|
|
|
"SET cz.sql.table.scan.push.down.filter = true", |
|
|
|
"SET cz.sql.table.scan.enable.ensure.filter = true", |
|
|
|
"SET cz.storage.always.prefetch.internal = true", |
|
|
|
"SET cz.optimizer.generate.columns.always.valid = true", |
|
|
|
"SET cz.sql.index.prewhere.enabled = true", |
|
|
|
# Storage optimization |
|
|
|
"SET cz.storage.parquet.enable.io.prefetch = false", |
|
|
|
"SET cz.optimizer.enable.mv.rewrite = false", |
|
|
|
"SET cz.sql.dump.as.lz4 = true", |
|
|
|
"SET cz.optimizer.limited.optimization.naive.query = true", |
|
|
|
"SET cz.sql.table.scan.enable.push.down.log = false", |
|
|
|
"SET cz.storage.use.file.format.local.stats = false", |
|
|
|
"SET cz.storage.local.file.object.cache.level = all", |
|
|
|
# Job execution optimization |
|
|
|
"SET cz.sql.job.fast.mode = true", |
|
|
|
"SET cz.storage.parquet.non.contiguous.read = true", |
|
|
|
"SET cz.sql.compaction.after.commit = true", |
|
|
|
] |
|
|
|
|
|
|
|
for hint in performance_hints: |
|
|
|
cursor.execute(hint) |
|
|
|
finally: |
|
|
|
# Restore original logging level |
|
|
|
clickzetta_logger.setLevel(original_level) |
|
|
|
|
|
|
|
except Exception: |
|
|
|
logger.exception("Failed to configure connection, continuing with defaults") |
|
|
|
|
|
|
|
def _is_connection_valid(self, connection: "Connection") -> bool: |
|
|
|
"""Check if connection is still valid.""" |
|
|
|
try: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute("SELECT 1") |
|
|
|
return True |
|
|
|
except Exception: |
|
|
|
return False |
|
|
|
|
|
|
|
def get_connection(self, config: ClickzettaConfig) -> "Connection": |
|
|
|
"""Get a connection from the pool or create a new one.""" |
|
|
|
config_key = self._get_config_key(config) |
|
|
|
|
|
|
|
# Ensure pool lock exists |
|
|
|
if config_key not in self._pool_locks: |
|
|
|
with self._lock: |
|
|
|
if config_key not in self._pool_locks: |
|
|
|
self._pool_locks[config_key] = threading.Lock() |
|
|
|
self._pools[config_key] = [] |
|
|
|
|
|
|
|
with self._pool_locks[config_key]: |
|
|
|
pool = self._pools[config_key] |
|
|
|
current_time = time.time() |
|
|
|
|
|
|
|
# Try to reuse existing connection |
|
|
|
while pool: |
|
|
|
connection, last_used = pool.pop(0) |
|
|
|
|
|
|
|
# Check if connection is not expired and still valid |
|
|
|
if current_time - last_used < self._connection_timeout and self._is_connection_valid(connection): |
|
|
|
logger.debug("Reusing ClickZetta connection from pool") |
|
|
|
return connection |
|
|
|
else: |
|
|
|
# Connection expired or invalid, close it |
|
|
|
try: |
|
|
|
connection.close() |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
|
|
|
|
# No valid connection found, create new one |
|
|
|
return self._create_connection(config) |
|
|
|
|
|
|
|
def return_connection(self, config: ClickzettaConfig, connection: "Connection") -> None: |
|
|
|
"""Return a connection to the pool.""" |
|
|
|
config_key = self._get_config_key(config) |
|
|
|
|
|
|
|
if config_key not in self._pool_locks: |
|
|
|
# Pool was cleaned up, just close the connection |
|
|
|
try: |
|
|
|
connection.close() |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
return |
|
|
|
|
|
|
|
with self._pool_locks[config_key]: |
|
|
|
pool = self._pools[config_key] |
|
|
|
|
|
|
|
# Only return to pool if not at capacity and connection is valid |
|
|
|
if len(pool) < self._max_pool_size and self._is_connection_valid(connection): |
|
|
|
pool.append((connection, time.time())) |
|
|
|
logger.debug("Returned ClickZetta connection to pool") |
|
|
|
else: |
|
|
|
# Pool full or connection invalid, close it |
|
|
|
try: |
|
|
|
connection.close() |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
|
|
|
|
def _cleanup_expired_connections(self) -> None: |
|
|
|
"""Clean up expired connections from all pools.""" |
|
|
|
current_time = time.time() |
|
|
|
|
|
|
|
with self._lock: |
|
|
|
for config_key in list(self._pools.keys()): |
|
|
|
if config_key not in self._pool_locks: |
|
|
|
continue |
|
|
|
|
|
|
|
with self._pool_locks[config_key]: |
|
|
|
pool = self._pools[config_key] |
|
|
|
valid_connections = [] |
|
|
|
|
|
|
|
for connection, last_used in pool: |
|
|
|
if current_time - last_used < self._connection_timeout: |
|
|
|
valid_connections.append((connection, last_used)) |
|
|
|
else: |
|
|
|
try: |
|
|
|
connection.close() |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
|
|
|
|
self._pools[config_key] = valid_connections |
|
|
|
|
|
|
|
def _start_cleanup_thread(self) -> None: |
|
|
|
"""Start background thread for connection cleanup.""" |
|
|
|
|
|
|
|
def cleanup_worker(): |
|
|
|
while not self._shutdown: |
|
|
|
try: |
|
|
|
time.sleep(60) # Cleanup every minute |
|
|
|
if not self._shutdown: |
|
|
|
self._cleanup_expired_connections() |
|
|
|
except Exception: |
|
|
|
logger.exception("Error in connection pool cleanup") |
|
|
|
|
|
|
|
self._cleanup_thread = threading.Thread(target=cleanup_worker, daemon=True) |
|
|
|
self._cleanup_thread.start() |
|
|
|
|
|
|
|
def shutdown(self) -> None: |
|
|
|
"""Shutdown connection pool and close all connections.""" |
|
|
|
self._shutdown = True |
|
|
|
|
|
|
|
with self._lock: |
|
|
|
for config_key in list(self._pools.keys()): |
|
|
|
if config_key not in self._pool_locks: |
|
|
|
continue |
|
|
|
|
|
|
|
with self._pool_locks[config_key]: |
|
|
|
pool = self._pools[config_key] |
|
|
|
for connection, _ in pool: |
|
|
|
try: |
|
|
|
connection.close() |
|
|
|
except Exception: |
|
|
|
pass |
|
|
|
pool.clear() |
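

# Illustrative pool usage (a sketch; assumes a populated ClickzettaConfig):
#
#     pool = ClickzettaConnectionPool.get_instance()
#     conn = pool.get_connection(config)
#     try:
#         with conn.cursor() as cursor:
#             cursor.execute("SELECT 1")
#     finally:
#         pool.return_connection(config, conn)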
|
|
|
|
|
|
|
|
|
|
|
class ClickzettaVector(BaseVector): |
|
|
|
""" |
|
|
|
Clickzetta vector storage implementation. |
|
|
|
@@ -82,70 +321,74 @@ class ClickzettaVector(BaseVector): |
|
|
|
super().__init__(collection_name) |
|
|
|
self._config = config |
|
|
|
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name |
|
|
|
|
|
|
self._connection_pool = ClickzettaConnectionPool.get_instance() |
|
|
|
self._init_write_queue() |
|
|
|
|
|
|
|
|
|
|
def _get_connection(self) -> "Connection": |
|
|
|
"""Get a connection from the pool.""" |
|
|
|
return self._connection_pool.get_connection(self._config) |
|
|
|
|
|
|
|
def _return_connection(self, connection: "Connection") -> None: |
|
|
|
"""Return a connection to the pool.""" |
|
|
|
self._connection_pool.return_connection(self._config, connection) |
|
|
|
|
|
|
|
class ConnectionContext: |
|
|
|
"""Context manager for borrowing and returning connections.""" |
|
|
|
|
|
|
|
def __init__(self, vector_instance: "ClickzettaVector"): |
|
|
|
self.vector = vector_instance |
|
|
|
            self.connection: Optional["Connection"] = None
|
|
|
def __enter__(self) -> "Connection": |
|
|
|
self.connection = self.vector._get_connection() |
|
|
|
return self.connection |
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
|
|
if self.connection: |
|
|
|
self.vector._return_connection(self.connection) |
|
|
|
|
|
|
|
def get_connection_context(self) -> "ClickzettaVector.ConnectionContext": |
|
|
|
"""Get a connection context manager.""" |
|
|
|
return self.ConnectionContext(self) |
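
    # Borrow/return sketch -- the context manager returns the connection to the
    # pool even if the body raises:
    #
    #     with self.get_connection_context() as conn:
    #         with conn.cursor() as cursor:
    #             cursor.execute("SELECT 1")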
|
|
|
|
|
|
|
def _parse_metadata(self, raw_metadata: str, row_id: str) -> dict: |
|
|
|
""" |
|
|
|
Parse metadata from JSON string with proper error handling and fallback. |
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
raw_metadata: Raw JSON string from database |
|
|
|
row_id: Row ID for fallback document_id |
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
Parsed metadata dict with guaranteed required fields |
|
|
|
""" |
|
|
|
try: |
|
|
|
|
|
|
if raw_metadata: |
|
|
|
metadata = json.loads(raw_metadata) |
|
|
|
|
|
|
|
|
|
|
# Handle double-encoded JSON |
|
|
|
if isinstance(metadata, str): |
|
|
|
metadata = json.loads(metadata) |
|
|
|
|
|
|
|
# Ensure we have a dict |
|
|
|
if not isinstance(metadata, dict): |
|
|
|
metadata = {} |
|
|
|
else: |
|
|
|
metadata = {} |
|
|
|
except (json.JSONDecodeError, TypeError): |
|
|
|
logger.exception("JSON parsing failed for metadata") |
|
|
|
# Fallback: extract document_id with regex |
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "") |
|
|
|
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} |
|
|
|
|
|
|
|
# Ensure required fields are set |
|
|
|
metadata["doc_id"] = row_id # segment id |
|
|
|
|
|
|
|
# Ensure document_id exists (critical for Dify's format_retrieval_documents) |
|
|
|
if "document_id" not in metadata: |
|
|
|
metadata["document_id"] = row_id # fallback to segment id |
|
|
|
|
|
|
|
return metadata |
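
    # Double-encoding example this method tolerates (illustrative values):
    #     raw = '"{\"document_id\": \"doc-1\"}"'        # a JSON string containing JSON
    #     json.loads(raw)             -> '{"document_id": "doc-1"}'   (still a str)
    #     json.loads(json.loads(raw)) -> {"document_id": "doc-1"}     (dict)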
|
|
|
|
|
|
|
@classmethod |
|
|
|
def _init_write_queue(cls): |
|
|
|
@@ -204,24 +447,33 @@ class ClickzettaVector(BaseVector): |
|
|
|
return "clickzetta" |
|
|
|
|
|
|
|
def _ensure_connection(self) -> "Connection": |
|
|
|
"""Ensure connection is available and return it.""" |
|
|
|
if self._connection is None: |
|
|
|
raise RuntimeError("Database connection not initialized") |
|
|
|
return self._connection |
|
|
|
"""Get a connection from the pool.""" |
|
|
|
return self._get_connection() |
|
|
|
|
|
|
|
def _table_exists(self) -> bool: |
|
|
|
"""Check if the table exists.""" |
|
|
|
try: |
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute(f"DESC {self._config.schema_name}.{self._table_name}") |
|
|
|
return True |
|
|
|
except Exception as e: |
|
|
|
error_message = str(e).lower() |
|
|
|
# Handle ClickZetta specific "table or view not found" errors |
|
|
|
if any( |
|
|
|
phrase in error_message |
|
|
|
for phrase in ["table or view not found", "czlh-42000", "semantic analysis exception"] |
|
|
|
): |
|
|
|
logger.debug("Table %s.%s does not exist", self._config.schema_name, self._table_name) |
|
|
|
return False |
|
|
|
|
|
|
# For other connection/permission errors, log warning but return False to avoid blocking cleanup |
|
|
|
logger.exception( |
|
|
|
"Table existence check failed for %s.%s, assuming it doesn't exist", |
|
|
|
self._config.schema_name, |
|
|
|
self._table_name, |
|
|
|
) |
|
|
|
return False |
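
    # Existence is probed with "DESC schema.table"; ClickZetta reports a missing
    # table through error text/codes (e.g. CZLH-42000) rather than a status flag.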
|
|
|
|
|
|
|
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): |
|
|
|
"""Create the collection and add initial documents.""" |
|
|
|
@@ -253,17 +505,17 @@ class ClickzettaVector(BaseVector): |
|
|
|
) COMMENT 'Dify RAG knowledge base vector storage table for document embeddings and content' |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute(create_table_sql) |
|
|
|
logger.info("Created table %s.%s", self._config.schema_name, self._table_name) |
|
|
|
|
|
|
|
|
|
|
# Create vector index |
|
|
|
self._create_vector_index(cursor) |
|
|
|
|
|
|
|
|
|
|
# Create inverted index for full-text search if enabled |
|
|
|
if self._config.enable_inverted_index: |
|
|
|
self._create_inverted_index(cursor) |
|
|
|
|
|
|
|
def _create_vector_index(self, cursor): |
|
|
|
"""Create HNSW vector index for similarity search.""" |
|
|
|
@@ -432,39 +684,53 @@ class ClickzettaVector(BaseVector): |
|
|
|
f"VALUES (?, ?, CAST(? AS JSON), CAST(? AS VECTOR({vector_dimension})))" |
|
|
|
) |
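
        # Each row binds (id, page_content, metadata_json, vector); the CASTs coerce
        # the JSON string and the float-list parameter to ClickZetta JSON/VECTOR columns.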
|
|
|
|
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
try: |
|
|
|
# Set session-level hints for batch insert operations |
|
|
|
# Note: executemany doesn't support hints parameter, so we set them as session variables |
|
|
|
# Temporarily suppress ClickZetta client logging to reduce noise |
|
|
|
clickzetta_logger = logging.getLogger("clickzetta") |
|
|
|
original_level = clickzetta_logger.level |
|
|
|
clickzetta_logger.setLevel(logging.WARNING) |
|
|
|
|
|
|
|
try: |
|
|
|
cursor.execute("SET cz.sql.job.fast.mode = true") |
|
|
|
cursor.execute("SET cz.sql.compaction.after.commit = true") |
|
|
|
cursor.execute("SET cz.storage.always.prefetch.internal = true") |
|
|
|
finally: |
|
|
|
# Restore original logging level |
|
|
|
clickzetta_logger.setLevel(original_level) |
|
|
|
|
|
|
|
cursor.executemany(insert_sql, data_rows) |
|
|
|
logger.info( |
|
|
|
"Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", |
|
|
|
batch_index // batch_size + 1, |
|
|
|
total_batches, |
|
|
|
len(data_rows), |
|
|
|
vector_dimension, |
|
|
|
) |
|
|
|
                except (RuntimeError, ValueError, TypeError, ConnectionError):
                    logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
                    logger.error("SQL template: %s", insert_sql)
                    logger.error("Sample data row: %s", data_rows[0] if data_rows else "None")
|
|
|
raise |
|
|
|
|
|
|
|
def text_exists(self, id: str) -> bool: |
|
|
|
"""Check if a document exists by ID.""" |
|
|
|
# Check if table exists first |
|
|
|
if not self._table_exists(): |
|
|
|
return False |
|
|
|
|
|
|
|
safe_id = self._safe_doc_id(id) |
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute( |
|
|
|
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", |
|
|
|
binding_params=[safe_id], |
|
|
|
) |
|
|
|
result = cursor.fetchone() |
|
|
|
return result[0] > 0 if result else False |
|
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None: |
|
|
|
"""Delete documents by IDs.""" |
|
|
|
@@ -482,13 +748,14 @@ class ClickzettaVector(BaseVector): |
|
|
|
def _delete_by_ids_impl(self, ids: list[str]) -> None: |
|
|
|
"""Implementation of delete by IDs (executed in write worker thread).""" |
|
|
|
safe_ids = [self._safe_doc_id(id) for id in ids] |
|
|
|
|
|
|
# Use parameterized query to prevent SQL injection |
|
|
|
placeholders = ",".join("?" for _ in safe_ids) |
|
|
|
sql = f"DELETE FROM {self._config.schema_name}.{self._table_name} WHERE id IN ({placeholders})" |
|
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute(sql, binding_params=safe_ids) |
|
|
|
|
|
|
|
def delete_by_metadata_field(self, key: str, value: str) -> None: |
|
|
|
"""Delete documents by metadata field.""" |
|
|
|
@@ -502,19 +769,28 @@ class ClickzettaVector(BaseVector): |
|
|
|
|
|
|
|
def _delete_by_metadata_field_impl(self, key: str, value: str) -> None: |
|
|
|
"""Implementation of delete by metadata field (executed in write worker thread).""" |
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
# Using JSON path to filter with parameterized query |
|
|
|
# Note: JSON path requires literal key name, cannot be parameterized |
|
|
|
# Use json_extract_string function for ClickZetta compatibility |
|
|
|
sql = ( |
|
|
|
f"DELETE FROM {self._config.schema_name}.{self._table_name} " |
|
|
|
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" |
|
|
|
) |
|
|
|
cursor.execute(sql, binding_params=[value]) |
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
|
|
|
"""Search for documents by vector similarity.""" |
|
|
|
# Check if table exists first |
|
|
|
if not self._table_exists(): |
|
|
|
logger.warning( |
|
|
|
"Table %s.%s does not exist, returning empty results", |
|
|
|
self._config.schema_name, |
|
|
|
self._table_name, |
|
|
|
) |
|
|
|
return [] |
|
|
|
|
|
|
|
top_k = kwargs.get("top_k", 10) |
|
|
|
score_threshold = kwargs.get("score_threshold", 0.0) |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
@@ -565,56 +841,31 @@ class ClickzettaVector(BaseVector): |
|
|
|
""" |
|
|
|
|
|
|
|
documents = [] |
|
|
|
        with self.get_connection_context() as connection:
            with connection.cursor() as cursor:
                # Use hints parameter for vector search optimization
                search_hints = {
                    "hints": {
                        "sdk.job.timeout": 60,  # Increase timeout for vector search
                        "cz.sql.job.fast.mode": True,
                        "cz.storage.parquet.vector.index.read.memory.cache": True,
                    }
                }
                cursor.execute(search_sql, search_hints)
                results = cursor.fetchall()

                for row in results:
                    # Parse metadata using centralized method
                    metadata = self._parse_metadata(row[2], row[0])

                    # Convert distance to a similarity score: cosine distance lies in
                    # [0, 2], so 1 - d/2 maps it to [0, 1]; other metrics use 1/(1 + d).
                    if self._config.vector_distance_function == "cosine_distance":
                        metadata["score"] = 1 - (row[3] / 2)
                    else:
                        metadata["score"] = 1 / (1 + row[3])

                    doc = Document(page_content=row[1], metadata=metadata)
                    documents.append(doc)
|
|
|
|
|
|
|
return documents |
|
|
|
|
|
|
|
@@ -624,6 +875,15 @@ class ClickzettaVector(BaseVector): |
|
|
|
logger.warning("Full-text search is not enabled. Enable inverted index in config.") |
|
|
|
return [] |
|
|
|
|
|
|
|
# Check if table exists first |
|
|
|
if not self._table_exists(): |
|
|
|
logger.warning( |
|
|
|
"Table %s.%s does not exist, returning empty results", |
|
|
|
self._config.schema_name, |
|
|
|
self._table_name, |
|
|
|
) |
|
|
|
return [] |
|
|
|
|
|
|
|
top_k = kwargs.get("top_k", 10) |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
|
|
|
|
@@ -659,62 +919,70 @@ class ClickzettaVector(BaseVector): |
|
|
|
""" |
|
|
|
|
|
|
|
documents = [] |
|
|
|
        with self.get_connection_context() as connection:
            with connection.cursor() as cursor:
                try:
                    # Use hints parameter for full-text search optimization
                    fulltext_hints = {
                        "hints": {
                            "sdk.job.timeout": 30,  # Timeout for full-text search
                            "cz.sql.job.fast.mode": True,
                            "cz.sql.index.prewhere.enabled": True,
                        }
                    }
                    cursor.execute(search_sql, fulltext_hints)
                    results = cursor.fetchall()

                    for row in results:
                        # Parse metadata from JSON string (may be double-encoded)
                        try:
                            if row[2]:
                                metadata = json.loads(row[2])

                                # If result is a string, it's double-encoded JSON - parse again
                                if isinstance(metadata, str):
                                    metadata = json.loads(metadata)

                                if not isinstance(metadata, dict):
                                    metadata = {}
                            else:
                                metadata = {}
                        except (json.JSONDecodeError, TypeError):
                            logger.exception("JSON parsing failed")
                            # Fallback: extract document_id with regex
                            doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
                            metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}

                        # Ensure required fields are set
                        metadata["doc_id"] = row[0]  # segment id

                        # Ensure document_id exists (critical for Dify's format_retrieval_documents)
                        if "document_id" not in metadata:
                            metadata["document_id"] = row[0]  # fallback to segment id

                        # Add a relevance score for full-text search
                        metadata["score"] = 1.0  # Clickzetta doesn't provide relevance scores
                        doc = Document(page_content=row[1], metadata=metadata)
                        documents.append(doc)
                except (RuntimeError, ValueError, TypeError, ConnectionError):
                    logger.exception("Full-text search failed")
                    # Fallback to LIKE search if full-text search fails
                    return self._search_by_like(query, **kwargs)
|
|
|
|
|
|
|
return documents |
|
|
|
|
|
|
|
def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]: |
|
|
|
"""Fallback search using LIKE operator.""" |
|
|
|
# Check if table exists first |
|
|
|
if not self._table_exists(): |
|
|
|
logger.warning( |
|
|
|
"Table %s.%s does not exist, returning empty results", |
|
|
|
self._config.schema_name, |
|
|
|
self._table_name, |
|
|
|
) |
|
|
|
return [] |
|
|
|
|
|
|
|
top_k = kwargs.get("top_k", 10) |
|
|
|
document_ids_filter = kwargs.get("document_ids_filter") |
|
|
|
|
|
|
|
@@ -746,58 +1014,33 @@ class ClickzettaVector(BaseVector): |
|
|
|
""" |
|
|
|
|
|
|
|
documents = [] |
|
|
|
        with self.get_connection_context() as connection:
            with connection.cursor() as cursor:
                # Use hints parameter for LIKE search optimization
                like_hints = {
                    "hints": {
                        "sdk.job.timeout": 20,  # Timeout for LIKE search
                        "cz.sql.job.fast.mode": True,
                    }
                }
                cursor.execute(search_sql, like_hints)
                results = cursor.fetchall()

                for row in results:
                    # Parse metadata using centralized method
                    metadata = self._parse_metadata(row[2], row[0])

                    metadata["score"] = 0.5  # Lower score for LIKE search
                    doc = Document(page_content=row[1], metadata=metadata)
                    documents.append(doc)
|
|
|
|
|
|
|
return documents |
|
|
|
|
|
|
|
def delete(self) -> None: |
|
|
|
"""Delete the entire collection.""" |
|
|
|
|
|
|
with self.get_connection_context() as connection: |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") |
|
|
|
|
|
|
|
def _format_vector_simple(self, vector: list[float]) -> str: |
|
|
|
"""Simple vector formatting for SQL queries.""" |