
fix: ensure vector database cleanup on dataset deletion regardless of document presence (affects all 33 vector databases) (#23574)

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
tags/1.7.2
yunqiqiliang committed 2 months ago
commit 62772e8871

.gitignore  (+1, -7)

# AI Assistant
.roo/
api/.env.backup

-# Clickzetta test credentials
-.env.clickzetta
-.env.clickzetta.test

-# Clickzetta plugin development folder (keep local, ignore for PR)
-clickzetta/
+/clickzetta

api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py  (+68, -59)

import queue
import threading
import uuid
-from typing import Any, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Optional


import clickzetta # type: ignore
from pydantic import BaseModel, model_validator
super().__init__(collection_name)
self._config = config
self._table_name = collection_name.replace("-", "_").lower() # Ensure valid table name
-self._connection: Optional["Connection"] = None
+self._connection: Optional[Connection] = None
self._init_connection()
self._init_write_queue()


service=self._config.service,
workspace=self._config.workspace,
vcluster=self._config.vcluster,
-schema=self._config.schema_name
+schema=self._config.schema_name,
)


# Set session parameters for better string handling and performance optimization
# Vector index optimization
"SET cz.storage.parquet.vector.index.read.memory.cache = true",
"SET cz.storage.parquet.vector.index.read.local.cache = false",

# Query optimization
"SET cz.sql.table.scan.push.down.filter = true",
"SET cz.sql.table.scan.enable.ensure.filter = true",
"SET cz.storage.always.prefetch.internal = true",
"SET cz.optimizer.generate.columns.always.valid = true",
"SET cz.sql.index.prewhere.enabled = true",

# Storage optimization
"SET cz.storage.parquet.enable.io.prefetch = false",
"SET cz.optimizer.enable.mv.rewrite = false",
"SET cz.sql.table.scan.enable.push.down.log = false",
"SET cz.storage.use.file.format.local.stats = false",
"SET cz.storage.local.file.object.cache.level = all",

# Job execution optimization
"SET cz.sql.job.fast.mode = true",
"SET cz.storage.parquet.non.contiguous.read = true",
-"SET cz.sql.compaction.after.commit = true"
+"SET cz.sql.compaction.after.commit = true",
]


for hint in performance_hints:
cursor.execute(hint)


logger.info("Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints))
logger.info(
"Applied %d performance optimization hints for ClickZetta vector operations", len(performance_hints)
)


except Exception:
# Catch any errors setting performance hints but continue with defaults
logger.info("Created vector index: %s", index_name)
except (RuntimeError, ValueError) as e:
error_msg = str(e).lower()
-if ("already exists" in error_msg or
-"already has index" in error_msg or
-"with the same type" in error_msg):
+if "already exists" in error_msg or "already has index" in error_msg or "with the same type" in error_msg:
logger.info("Vector index already exists: %s", e)
else:
logger.exception("Failed to create vector index")
for idx in existing_indexes:
idx_str = str(idx).lower()
# More precise check: look for inverted index specifically on the content column
-if ("inverted" in idx_str and
-Field.CONTENT_KEY.value.lower() in idx_str and
-(index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)):
+if (
+"inverted" in idx_str
+and Field.CONTENT_KEY.value.lower() in idx_str
+and (index_name.lower() in idx_str or f"idx_{self._table_name}_text" in idx_str)
+):
logger.info("Inverted index already exists on column %s: %s", Field.CONTENT_KEY.value, idx)
return
except (RuntimeError, ValueError) as e:
except (RuntimeError, ValueError) as e:
error_msg = str(e).lower()
# Handle ClickZetta specific error messages
-if (("already exists" in error_msg or
-"already has index" in error_msg or
-"with the same type" in error_msg or
-"cannot create inverted index" in error_msg) and
-"already has index" in error_msg):
+if (
+"already exists" in error_msg
+or "already has index" in error_msg
+or "with the same type" in error_msg
+or "cannot create inverted index" in error_msg
+) and "already has index" in error_msg:
logger.info("Inverted index already exists on column %s", Field.CONTENT_KEY.value)
# Try to get the existing index name for logging
try:
logger.warning("Failed to create inverted index: %s", e)
# Continue without inverted index - full-text search will fall back to LIKE



def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add documents with embeddings to the collection."""
if not documents:
total_batches = (len(documents) + batch_size - 1) // batch_size


for i in range(0, len(documents), batch_size):
-batch_docs = documents[i:i + batch_size]
-batch_embeddings = embeddings[i:i + batch_size]
+batch_docs = documents[i : i + batch_size]
+batch_embeddings = embeddings[i : i + batch_size]


# Execute batch insert through write queue
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)


-def _insert_batch(self, batch_docs: list[Document], batch_embeddings: list[list[float]],
-batch_index: int, batch_size: int, total_batches: int):
+def _insert_batch(
+self,
+batch_docs: list[Document],
+batch_embeddings: list[list[float]],
+batch_index: int,
+batch_size: int,
+total_batches: int,
+):
"""Insert a batch of documents using parameterized queries (executed in write worker thread)."""
if not batch_docs or not batch_embeddings:
logger.warning("Empty batch provided, skipping insertion")


# According to ClickZetta docs, vector should be formatted as array string
# for external systems: '[1.0, 2.0, 3.0]'
-vector_str = '[' + ','.join(map(str, embedding)) + ']'
+vector_str = "[" + ",".join(map(str, embedding)) + "]"
data_rows.append([doc_id, content, metadata_json, vector_str])


# Check if we have any valid data to insert


cursor.executemany(insert_sql, data_rows)
logger.info(
-f"Inserted batch {batch_index // batch_size + 1}/{total_batches} "
-f"({len(data_rows)} valid docs using parameterized query with VECTOR({vector_dimension}) cast)"
+"Inserted batch %d/%d (%d valid docs using parameterized query with VECTOR(%d) cast)",
+batch_index // batch_size + 1,
+total_batches,
+len(data_rows),
+vector_dimension,
)
except (RuntimeError, ValueError, TypeError, ConnectionError) as e:
-logger.exception("Parameterized SQL execution failed for %d documents: %s", len(data_rows), e)
+logger.exception("Parameterized SQL execution failed for %d documents", len(data_rows))
logger.exception("SQL template: %s", insert_sql)
-logger.exception("Sample data row: %s", data_rows[0] if data_rows else 'None')
+logger.exception("Sample data row: %s", data_rows[0] if data_rows else "None")
raise


def text_exists(self, id: str) -> bool:
connection = self._ensure_connection()
with connection.cursor() as cursor:
cursor.execute(
-f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?",
-[safe_id]
+f"SELECT COUNT(*) FROM {self._config.schema_name}.{self._table_name} WHERE id = ?", [safe_id]
)
result = cursor.fetchone()
return result[0] > 0 if result else False
# Using JSON path to filter with parameterized query
# Note: JSON path requires literal key name, cannot be parameterized
# Use json_extract_string function for ClickZetta compatibility
-sql = (f"DELETE FROM {self._config.schema_name}.{self._table_name} "
-f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?")
+sql = (
+f"DELETE FROM {self._config.schema_name}.{self._table_name} "
+f"WHERE json_extract_string({Field.METADATA_KEY.value}, '$.{key}') = ?"
+)
cursor.execute(sql, [value])


def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
distance_func = "COSINE_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
-filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
-f"{query_vector_str}) < {2 - score_threshold}")
+filter_clauses.append(
+f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {2 - score_threshold}"
+)
else:
# For L2 distance, smaller is better
distance_func = "L2_DISTANCE"
if score_threshold > 0:
query_vector_str = f"CAST('[{self._format_vector_simple(query_vector)}]' AS VECTOR({vector_dimension}))"
-filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, "
-f"{query_vector_str}) < {score_threshold}")
+filter_clauses.append(f"{distance_func}({Field.VECTOR.value}, {query_vector_str}) < {score_threshold}")

where_clause = " AND ".join(filter_clauses) if filter_clauses else "1=1"


with connection.cursor() as cursor:
# Use hints parameter for vector search optimization
search_hints = {
-'hints': {
-'sdk.job.timeout': 60, # Increase timeout for vector search
-'cz.sql.job.fast.mode': True,
-'cz.storage.parquet.vector.index.read.memory.cache': True
+"hints": {
+"sdk.job.timeout": 60, # Increase timeout for vector search
+"cz.sql.job.fast.mode": True,
+"cz.storage.parquet.vector.index.read.memory.cache": True,
}
}
cursor.execute(search_sql, parameters=search_hints)
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
-logger.error("JSON parsing failed: %s", e)
+logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
import re
-doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
+doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}


# Ensure required fields are set
try:
# Use hints parameter for full-text search optimization
fulltext_hints = {
-'hints': {
-'sdk.job.timeout': 30, # Timeout for full-text search
-'cz.sql.job.fast.mode': True,
-'cz.sql.index.prewhere.enabled': True
+"hints": {
+"sdk.job.timeout": 30, # Timeout for full-text search
+"cz.sql.job.fast.mode": True,
+"cz.sql.index.prewhere.enabled": True,
}
}
cursor.execute(search_sql, parameters=fulltext_hints)
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
-logger.error("JSON parsing failed: %s", e)
+logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
import re
-doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
+doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}


# Ensure required fields are set
with connection.cursor() as cursor:
# Use hints parameter for LIKE search optimization
like_hints = {
-'hints': {
-'sdk.job.timeout': 20, # Timeout for LIKE search
-'cz.sql.job.fast.mode': True
+"hints": {
+"sdk.job.timeout": 20, # Timeout for LIKE search
+"cz.sql.job.fast.mode": True,
}
}
cursor.execute(search_sql, parameters=like_hints)
else:
metadata = {}
except (json.JSONDecodeError, TypeError) as e:
-logger.error("JSON parsing failed: %s", e)
+logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex
import re
-doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ''))
+doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', str(row[2] or ""))
metadata = {"document_id": doc_id_match.group(1)} if doc_id_match else {}


# Ensure required fields are set
with connection.cursor() as cursor:
cursor.execute(f"DROP TABLE IF EXISTS {self._config.schema_name}.{self._table_name}")



def _format_vector_simple(self, vector: list[float]) -> str:
"""Simple vector formatting for SQL queries."""
-return ','.join(map(str, vector))
+return ",".join(map(str, vector))


def _safe_doc_id(self, doc_id: str) -> str:
"""Ensure doc_id is safe for SQL and doesn't contain special characters."""
# Remove or replace potentially problematic characters
safe_id = str(doc_id)
# Only allow alphanumeric, hyphens, underscores
-safe_id = ''.join(c for c in safe_id if c.isalnum() or c in '-_')
+safe_id = "".join(c for c in safe_id if c.isalnum() or c in "-_")
if not safe_id: # If all characters were removed
return str(uuid.uuid4())
return safe_id[:255] # Limit length





class ClickzettaVectorFactory(AbstractVectorFactory):
"""Factory for creating Clickzetta vector instances."""


collection_name = Dataset.gen_collection_name_by_id(dataset.id).lower()


return ClickzettaVector(collection_name=collection_name, config=config)
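For reference, the insert path above serializes each embedding to a bracketed string and binds it as a parameter, letting the SQL cast it to VECTOR(n) server-side. A minimal sketch of that pattern follows; the module's actual insert_sql, Field enum columns, and _safe_doc_id handling are not visible in this diff, so the column names and doc_id lookup here are assumptions:

import json


def insert_batch_sketch(cursor, schema_name: str, table_name: str, batch_docs, batch_embeddings):
    """Hypothetical re-creation of the parameterized batch insert pattern above."""
    vector_dimension = len(batch_embeddings[0])
    # Assumed column names (id, page_content, metadata, vector); the real insert_sql is elided in the diff.
    insert_sql = (
        f"INSERT INTO {schema_name}.{table_name} (id, page_content, metadata, vector) "
        f"VALUES (?, ?, ?, CAST(? AS VECTOR({vector_dimension})))"
    )
    data_rows = []
    for doc, embedding in zip(batch_docs, batch_embeddings):
        # Per the ClickZetta docs referenced above, external systems pass vectors
        # as a bracketed string literal: '[1.0, 2.0, 3.0]'
        vector_str = "[" + ",".join(map(str, embedding)) + "]"
        # Using doc.metadata["doc_id"] as the row id is an assumption for this sketch.
        data_rows.append(
            [doc.metadata["doc_id"], doc.page_content, json.dumps(doc.metadata, ensure_ascii=False), vector_str]
        )
    cursor.executemany(insert_sql, data_rows)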


api/tasks/clean_dataset_task.py  (+7, -5)

documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()


+# Fix: Always clean vector database resources regardless of document existence
+# This ensures all 33 vector databases properly drop tables/collections/indices
+if doc_form is None:
+raise ValueError("Index type must be specified.")
+index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)

if documents is None or len(documents) == 0:
logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
-# Specify the index type before initializing the index processor
-if doc_form is None:
-raise ValueError("Index type must be specified.")
-index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)


for document in documents:
db.session.delete(document)
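Reconstructed from the hunk above, the resulting flow in clean_dataset_task runs the index-processor cleanup before the document check rather than inside the else branch, so the vector store table/collection/index is dropped even when the dataset has no documents left. A simplified sketch (imports, session handling, and the rest of the task omitted):

# Simplified flow after this change; names follow the diff above.
if doc_form is None:
    raise ValueError("Index type must be specified.")

# Vector database cleanup now runs unconditionally for every vector DB backend.
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)

if documents is None or len(documents) == 0:
    logging.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
    logging.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
    for document in documents:
        db.session.delete(document)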

api/tests/integration_tests/vdb/clickzetta/test_clickzetta.py  (+10, -23)

)


with setup_mock_redis():
-vector = ClickzettaVector(
-collection_name="test_collection_" + str(os.getpid()),
-config=config
-)
+vector = ClickzettaVector(collection_name="test_collection_" + str(os.getpid()), config=config)


yield vector


"category": "technical" if i % 2 == 0 else "general", "category": "technical" if i % 2 == 0 else "general",
"document_id": f"doc_{i // 3}", # Group documents "document_id": f"doc_{i // 3}", # Group documents
"importance": i, "importance": i,
}
},
) )
documents.append(doc) documents.append(doc)
# Create varied embeddings # Create varied embeddings


# Test vector search with document filter
query_vector = [0.5, 1.0, 1.5, 2.0]
-results = vector_store.search_by_vector(
-query_vector,
-top_k=5,
-document_ids_filter=["doc_0", "doc_1"]
-)
+results = vector_store.search_by_vector(query_vector, top_k=5, document_ids_filter=["doc_0", "doc_1"])
assert len(results) > 0
# All results should belong to doc_0 or doc_1 groups
for result in results:
assert result.metadata["document_id"] in ["doc_0", "doc_1"]


# Test score threshold
-results = vector_store.search_by_vector(
-query_vector,
-top_k=10,
-score_threshold=0.5
-)
+results = vector_store.search_by_vector(query_vector, top_k=10, score_threshold=0.5)
# Check that all results have a score above threshold
for result in results:
assert result.metadata.get("score", 0) >= 0.5
for i in range(batch_size):
doc = Document(
page_content=f"Batch document {i}: This is a test document for batch processing.",
-metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"}
+metadata={"doc_id": f"batch_doc_{i}", "batch": "test_batch"},
)
documents.append(doc)
embeddings.append([0.1 * (i % 10), 0.2 * (i % 10), 0.3 * (i % 10), 0.4 * (i % 10)])
# Test special characters in content
special_doc = Document(
page_content="Special chars: 'quotes', \"double\", \\backslash, \n newline",
-metadata={"doc_id": "special_doc", "test": "edge_case"}
+metadata={"doc_id": "special_doc", "test": "edge_case"},
)
embeddings = [[0.1, 0.2, 0.3, 0.4]]


# Prepare documents with various language content
documents = [
Document(
-page_content="云器科技提供强大的Lakehouse解决方案",
-metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
+page_content="云器科技提供强大的Lakehouse解决方案", metadata={"doc_id": "cn_doc_1", "lang": "chinese"}
),
Document(
page_content="Clickzetta provides powerful Lakehouse solutions",
-metadata={"doc_id": "en_doc_1", "lang": "english"}
+metadata={"doc_id": "en_doc_1", "lang": "english"},
),
Document(
-page_content="Lakehouse是现代数据架构的重要组成部分",
-metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
+page_content="Lakehouse是现代数据架构的重要组成部分", metadata={"doc_id": "cn_doc_2", "lang": "chinese"}
),
Document(
page_content="Modern data architecture includes Lakehouse technology",
-metadata={"doc_id": "en_doc_2", "lang": "english"}
+metadata={"doc_id": "en_doc_2", "lang": "english"},
),
]
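These multilingual fixtures are presumably exercised through the full-text search path (inverted index with a LIKE fallback, per the vector module changes above). A minimal usage sketch, assuming the vector_store fixture from this file and the search_by_full_text method of Dify's common vector-store interface; the embeddings and assertions are illustrative only:

# Illustrative only: index the documents above, then query by keyword.
embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in documents]
vector_store.add_texts(documents, embeddings)

# Both the Chinese and English "Lakehouse" documents should be matched.
results = vector_store.search_by_full_text("Lakehouse", top_k=4)
assert len(results) > 0
for result in results:
    assert "Lakehouse" in result.page_content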



api/tests/integration_tests/vdb/clickzetta/test_docker_integration.py  (+11, -11)

""" """
Test Clickzetta integration in Docker environment Test Clickzetta integration in Docker environment
""" """

import os
import time


service=os.getenv("CLICKZETTA_SERVICE", "api.clickzetta.com"),
workspace=os.getenv("CLICKZETTA_WORKSPACE", "test_workspace"),
vcluster=os.getenv("CLICKZETTA_VCLUSTER", "default"),
-database=os.getenv("CLICKZETTA_SCHEMA", "dify")
+database=os.getenv("CLICKZETTA_SCHEMA", "dify"),
)


with conn.cursor() as cursor:


# Check if test collection exists
test_collection = "collection_test_dataset"
-if test_collection in [t[1] for t in tables if t[0] == 'dify']:
+if test_collection in [t[1] for t in tables if t[0] == "dify"]:
cursor.execute(f"DESCRIBE dify.{test_collection}")
columns = cursor.fetchall()
print(f"✓ Table structure for {test_collection}:")
print(f"✗ Connection test failed: {e}")
return False



def test_dify_api():
"""Test Dify API with Clickzetta backend"""
print("\n=== Testing Dify API ===")
print(f"✗ API test failed: {e}")
return False



def verify_table_structure():
"""Verify the table structure meets Dify requirements"""
print("\n=== Verifying Table Structure ===")
"id": "VARCHAR",
"page_content": "VARCHAR",
"metadata": "VARCHAR", # JSON stored as VARCHAR in Clickzetta
-"vector": "ARRAY<FLOAT>"
+"vector": "ARRAY<FLOAT>",
}


-expected_metadata_fields = [
-"doc_id",
-"doc_hash",
-"document_id",
-"dataset_id"
-]
+expected_metadata_fields = ["doc_id", "doc_hash", "document_id", "dataset_id"]


print("✓ Expected table structure:") print("✓ Expected table structure:")
for col, dtype in expected_columns.items(): for col, dtype in expected_columns.items():


return True



def main():
"""Run all tests"""
print("Starting Clickzetta integration tests for Dify Docker\n")
results.append((test_name, False))


# Summary
-print("\n" + "="*50)
+print("\n" + "=" * 50)
print("Test Summary:")
-print("="*50)
+print("=" * 50)


passed = sum(1 for _, success in results if success)
total = len(results)
print("\n⚠️ Some tests failed. Please check the errors above.")
return 1



if __name__ == "__main__":
exit(main())
