|
|
|
@@ -3,7 +3,7 @@ import logging |
|
|
|
import queue |
|
|
|
import threading |
|
|
|
import uuid |
|
|
|
from typing import Any, Optional, TYPE_CHECKING |
|
|
|
from typing import TYPE_CHECKING, Any, Optional |
|
|
|
|
|
|
|
import clickzetta # type: ignore |
|
|
|
from pydantic import BaseModel, model_validator |
|
|
|
@@ -82,7 +82,7 @@ class ClickzettaVector(BaseVector): |
|
|
|
super().__init__(collection_name) |
|
|
|
self._config = config |
|
|
|
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name |
|
|
|
self._connection: Optional["Connection"] = None |
|
|
|
self._connection: Optional[Connection] = None |
|
|
|
self._init_connection() |
|
|
|
self._init_write_queue() |
|
|
|
|
|
|
|
@@ -95,7 +95,7 @@ class ClickzettaVector(BaseVector): |
|
|
|
service=self._config.service, |
|
|
|
workspace=self._config.workspace, |
|
|
|
vcluster=self._config.vcluster, |
|
|
|
schema=self._config.schema_name |
|
|
|
schema=self._config.schema_name, |
|
|
|
) |
|
|
|
|
|
|
|
# Set session parameters for better string handling and performance optimization |
|
|
|
@@ -116,14 +116,12 @@ class ClickzettaVector(BaseVector): |
|
|
|
# Vector index optimization |
|
|
|
"SET cz.storage.parquet.vector.index.read.memory.cache = true", |
|
|
|
"SET cz.storage.parquet.vector.index.read.local.cache = false", |
|
|
|
|
|
|
|
# Query optimization |
|
|
|
"SET cz.sql.table.scan.push.down.filter = true", |
|
|
|
"SET cz.sql.table.scan.enable.ensure.filter = true", |
|
|
|
"SET cz.storage.always.prefetch.internal = true", |
|
|
|
"SET cz.optimizer.generate.columns.always.valid = true", |
|
|
|
"SET cz.sql.index.prewhere.enabled = true", |
|
|
|
|
|
|
|
# Storage optimization |
|
|
|
"SET cz.storage.parquet.enable.io.prefetch = false", |
|
|
|
"SET cz.optimizer.enable.mv.rewrite = false", |
|
|
|
@@ -132,17 +130,18 @@ class ClickzettaVector(BaseVector): |
|
|
|
"SET cz.sql.table.scan.enable.push.down.log = false", |
|
|
|
"SET cz.storage.use.file.format.local.stats = false", |
|
|
|
"SET cz.storage.local.file.object.cache.level = all", |
|
|
|
|
|
|
|
# Job execution optimization |
|
|
|
"SET cz.sql.job.fast.mode = true", |
|
|
|
"SET cz.storage.parquet.non.contiguous.read = true", |
|
|
|
"SET cz.sql.compaction.after.commit = true" |
|
|
|
"SET cz.sql.compaction.after.commit = true", |
|
|
|
] |
|
|
|
|
|
|
|
for hint in performance_hints: |
|
|
|
cursor.execute(hint) |
|
|
|
|
|
|
|
logger.info("Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints)) |
|
|
|
logger.info( |
|
|
|
"Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints) |
|
|
|
) |
|
|
|
|
|
|
|
except Exception: |
|
|
|
# Catch any errors setting performance hints but continue with defaults |
|
|
|
@@ -298,9 +297,7 @@ class ClickzettaVector(BaseVector): |
|
|
|
logger.info("Created vector index: %s", index_name) |
|
|
|
except (RuntimeError, ValueError) as e: |
|
|
|
error_msg = str(e).lower() |
|
|
|
if ("already exists" in error_msg or |
|
|
|
"already has index" in error_msg or |
|
|
|
"with the same type" in error_msg): |
|
|
|
if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg: |
|
|
|
logger.info("Vector index already exists: %s", e) |
|
|
|
else: |
|
|
|
logger.exception("Failed to create vector index") |
|
|
|
@@ -318,9 +315,11 @@ class ClickzettaVector(BaseVector): |
|
|
|
for idx in existing_indexes: |
|
|
|
idx_str = str(idx).lower() |
|
|
|
# More precise check: look for inverted index specifically on the content column |
|
|
|
if ("inverted" in idx_str and |
|
|
|
Field.CONTENT_KEY.value.lower() in idx_str and |
|
|
|
(index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)): |
|
|
|
if ( |
|
|
|
"inverted" in idx_str |
|
|
|
and Field.CONTENT_KEY.value.lower() in idx_str |
|
|
|
and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str) |
|
|
|
): |
|
|
|
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx) |
|
|
|
return |
|
|
|
except (RuntimeError, ValueError) as e: |
|
|
|
@@ -340,11 +339,12 @@ class ClickzettaVector(BaseVector): |
|
|
|
except (RuntimeError, ValueError) as e: |
|
|
|
error_msg = str(e).lower() |
|
|
|
# Handle ClickZetta specific error messages |
|
|
|
if (("already exists" in error_msg or |
|
|
|
"already has index" in error_msg or |
|
|
|
"with the same type" in error_msg or |
|
|
|
"cannot create inverted index" in error_msg) and |
|
|
|
"already has index" in error_msg): |
|
|
|
if ( |
|
|
|
"already exists" in error_msg |
|
|
|
or "already has index" in error_msg |
|
|
|
or "with the same type" in error_msg |
|
|
|
or "cannot create inverted index" in error_msg |
|
|
|
) and "already has index" in error_msg: |
|
|
|
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value) |
|
|
|
# Try to get the existing index name for logging |
|
|
|
try: |
|
|
|
@@ -360,7 +360,6 @@ class ClickzettaVector(BaseVector): |
|
|
|
logger.warning("Failed to create inverted index: %s", e) |
|
|
|
# Continue without inverted index - full-text search will fall back to LIKE |
|
|
|
|
|
|
|
|
|
|
|
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): |
|
|
|
"""Add documents with embeddings to the collection.""" |
|
|
|
if not documents: |
|
|
|
@@ -370,14 +369,20 @@ class ClickzettaVector(BaseVector): |
|
|
|
total_batches = (len(documents) + batch_size - 1) // batch_size |
|
|
|
|
|
|
|
for i in range(0, len(documents), batch_size): |
|
|
|
batch_docs = documents[i:i + batch_size] |
|
|
|
batch_embeddings = embeddings[i:i + batch_size] |
|
|
|
batch_docs = documents[i : i + batch_size] |
|
|
|
batch_embeddings = embeddings[i : i + batch_size] |
|
|
|
|
|
|
|
# Execute batch insert through write queue |
|
|
|
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) |
|
|
|
|
|
|
|
def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]], |
|
|
|
batch_index: int, batch_size: int, total_batches: int): |
|
|
|
def _insert_batch( |
|
|
|
self, |
|
|
|
batch_docs: list[Document], |
|
|
|
batch_embeddings: list[list[float]], |
|
|
|
batch_index: int, |
|
|
|
batch_size: int, |
|
|
|
total_batches: int, |
|
|
|
): |
|
|
|
"""Insert a batch of documents using parameterized queries (executed in write worker thread).""" |
|
|
|
if not batch_docs or not batch_embeddings: |
|
|
|
logger.warning("Empty batch provided, skipping insertion") |
|
|
|
@@ -411,7 +416,7 @@ class ClickzettaVector(BaseVector): |
|
|
|
|
|
|
|
# According to ClickZetta docs, vector should be formatted as array string |
|
|
|
# for external systems: '[1.0, 2.0, 3.0]' |
|
|
|
vector_str = '[' + ','.join(map(str, embedding)) + ']' |
|
|
|
vector_str = "[" + ",".join(map(str, embedding)) + "]" |
|
|
|
data_rows.append([doc_id, content, metadata_json, vector_str]) |
|
|
|
|
|
|
|
# Check if we have any valid data to insert |
|
|
|
@@ -438,13 +443,16 @@ class ClickzettaVector(BaseVector): |
|
|
|
|
|
|
|
cursor.executemany(insert_sql, data_rows) |
|
|
|
logger.info( |
|
|
|
f"Inserted batch {batch_index // batch_size + 1}/{total_batches} " |
|
|
|
f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)" |
|
|
|
"Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)", |
|
|
|
batch_index // batch_size + 1, |
|
|
|
total_batches, |
|
|
|
len(data_rows), |
|
|
|
vector_dimension, |
|
|
|
) |
|
|
|
except (RuntimeError, ValueError, TypeError, ConnectionError) as e: |
|
|
|
logger.exception("Parameterized SQL execution failed for %d documents: %s", len(data_rows), e) |
|
|
|
logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows)) |
|
|
|
logger.exception("SQL template: %s", insert_sql) |
|
|
|
logger.exception("Sample data row: %s", data_rows[0] if data_rows else 'None') |
|
|
|
logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None") |
|
|
|
raise |
|
|
|
|
|
|
|
def text_exists(self, id: str) -> bool: |
|
|
|
@@ -453,8 +461,7 @@ class ClickzettaVector(BaseVector): |
|
|
|
connection = self._ensure_connection() |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute( |
|
|
|
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", |
|
|
|
[safe_id] |
|
|
|
f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id] |
|
|
|
) |
|
|
|
result = cursor.fetchone() |
|
|
|
return result[0] > 0 if result else False |
|
|
|
@@ -500,8 +507,10 @@ class ClickzettaVector(BaseVector): |
|
|
|
# Using JSON path to filter with parameterized query |
|
|
|
# Note: JSON path requires literal key name, cannot be parameterized |
|
|
|
# Use json_extract_string function for ClickZetta compatibility |
|
|
|
sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} " |
|
|
|
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?") |
|
|
|
sql = ( |
|
|
|
f"DELETE FROM {self._config.schema_name}.{self._table_name} " |
|
|
|
f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?" |
|
|
|
) |
|
|
|
cursor.execute(sql, [value]) |
|
|
|
|
|
|
|
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
|
|
|
@@ -532,15 +541,15 @@ class ClickzettaVector(BaseVector): |
|
|
|
distance_func = "COSINE_DISTANCE" |
|
|
|
if score_threshold > 0: |
|
|
|
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" |
|
|
|
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, " |
|
|
|
f"{query_vector_str}) < {2 - score_threshold}") |
|
|
|
filter_clauses.append( |
|
|
|
f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}" |
|
|
|
) |
|
|
|
else: |
|
|
|
# For L2 distance, smaller is better |
|
|
|
distance_func = "L2_DISTANCE" |
|
|
|
if score_threshold > 0: |
|
|
|
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))" |
|
|
|
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, " |
|
|
|
f"{query_vector_str}) < {score_threshold}") |
|
|
|
filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}") |
|
|
|
|
|
|
|
where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1" |
|
|
|
|
|
|
|
@@ -560,10 +569,10 @@ class ClickzettaVector(BaseVector): |
|
|
|
with connection.cursor() as cursor: |
|
|
|
# Use hints parameter for vector search optimization |
|
|
|
search_hints = { |
|
|
|
'hints': { |
|
|
|
'sdk.job.timeout': 60, # Increase timeout for vector search |
|
|
|
'cz.sql.job.fast.mode': True, |
|
|
|
'cz.storage.parquet.vector.index.read.memory.cache': True |
|
|
|
"hints": { |
|
|
|
"sdk.job.timeout": 60, # Increase timeout for vector search |
|
|
|
"cz.sql.job.fast.mode": True, |
|
|
|
"cz.storage.parquet.vector.index.read.memory.cache": True, |
|
|
|
} |
|
|
|
} |
|
|
|
cursor.execute(search_sql, parameters=search_hints) |
|
|
|
@@ -584,10 +593,11 @@ class ClickzettaVector(BaseVector): |
|
|
|
else: |
|
|
|
metadata = {} |
|
|
|
except (json.JSONDecodeError, TypeError) as e: |
|
|
|
logger.error("JSON parsing failed: %s", e) |
|
|
|
logger.exception("JSON parsing failed") |
|
|
|
# Fallback: extract document_id with regex |
|
|
|
import re |
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) |
|
|
|
|
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) |
|
|
|
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} |
|
|
|
|
|
|
|
# Ensure required fields are set |
|
|
|
@@ -654,10 +664,10 @@ class ClickzettaVector(BaseVector): |
|
|
|
try: |
|
|
|
# Use hints parameter for full-text search optimization |
|
|
|
fulltext_hints = { |
|
|
|
'hints': { |
|
|
|
'sdk.job.timeout': 30, # Timeout for full-text search |
|
|
|
'cz.sql.job.fast.mode': True, |
|
|
|
'cz.sql.index.prewhere.enabled': True |
|
|
|
"hints": { |
|
|
|
"sdk.job.timeout": 30, # Timeout for full-text search |
|
|
|
"cz.sql.job.fast.mode": True, |
|
|
|
"cz.sql.index.prewhere.enabled": True, |
|
|
|
} |
|
|
|
} |
|
|
|
cursor.execute(search_sql, parameters=fulltext_hints) |
|
|
|
@@ -678,10 +688,11 @@ class ClickzettaVector(BaseVector): |
|
|
|
else: |
|
|
|
metadata = {} |
|
|
|
except (json.JSONDecodeError, TypeError) as e: |
|
|
|
logger.error("JSON parsing failed: %s", e) |
|
|
|
logger.exception("JSON parsing failed") |
|
|
|
# Fallback: extract document_id with regex |
|
|
|
import re |
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) |
|
|
|
|
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) |
|
|
|
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} |
|
|
|
|
|
|
|
# Ensure required fields are set |
|
|
|
@@ -739,9 +750,9 @@ class ClickzettaVector(BaseVector): |
|
|
|
with connection.cursor() as cursor: |
|
|
|
# Use hints parameter for LIKE search optimization |
|
|
|
like_hints = { |
|
|
|
'hints': { |
|
|
|
'sdk.job.timeout': 20, # Timeout for LIKE search |
|
|
|
'cz.sql.job.fast.mode': True |
|
|
|
"hints": { |
|
|
|
"sdk.job.timeout": 20, # Timeout for LIKE search |
|
|
|
"cz.sql.job.fast.mode": True, |
|
|
|
} |
|
|
|
} |
|
|
|
cursor.execute(search_sql, parameters=like_hints) |
|
|
|
@@ -762,10 +773,11 @@ class ClickzettaVector(BaseVector): |
|
|
|
else: |
|
|
|
metadata = {} |
|
|
|
except (json.JSONDecodeError, TypeError) as e: |
|
|
|
logger.error("JSON parsing failed: %s", e) |
|
|
|
logger.exception("JSON parsing failed") |
|
|
|
# Fallback: extract document_id with regex |
|
|
|
import re |
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or '')) |
|
|
|
|
|
|
|
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or "")) |
|
|
|
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {} |
|
|
|
|
|
|
|
# Ensure required fields are set |
|
|
|
@@ -787,10 +799,9 @@ class ClickzettaVector(BaseVector): |
|
|
|
with connection.cursor() as cursor: |
|
|
|
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}") |
|
|
|
|
|
|
|
|
|
|
|
def _format_vector_simple(self, vector: list[float]) -> str: |
|
|
|
"""Simple vector formatting for SQL queries.""" |
|
|
|
return ','.join(map(str, vector)) |
|
|
|
return ",".join(map(str, vector)) |
|
|
|
|
|
|
|
def _safe_doc_id(self, doc_id: str) -> str: |
|
|
|
"""Ensure doc_id is safe for SQL and doesn't contain special characters.""" |
|
|
|
@@ -799,13 +810,12 @@ class ClickzettaVector(BaseVector): |
|
|
|
# Remove or replace potentially problematic characters |
|
|
|
safe_id = str(doc_id) |
|
|
|
# Only allow alphanumeric, hyphens, underscores |
|
|
|
safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_') |
|
|
|
safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_") |
|
|
|
if not safe_id: # If all characters were removed |
|
|
|
return str(uuid.uuid4()) |
|
|
|
return safe_id[:255] # Limit length |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClickzettaVectorFactory(AbstractVectorFactory): |
|
|
|
"""Factory for creating Clickzetta vector instances.""" |
|
|
|
|
|
|
|
@@ -831,4 +841,3 @@ class ClickzettaVectorFactory(AbstractVectorFactory): |
|
|
|
collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower() |
|
|
|
|
|
|
|
return ClickzettaVector(collection_name=collection_name, config=config) |
|
|
|
|