@@ -1,5 +1,7 @@
 import json
-from typing import Any
+import logging
+from typing import Any, Optional
+from urllib.parse import urlparse

 import requests
 from elasticsearch import Elasticsearch
@@ -7,16 +9,20 @@ from flask import current_app
 from pydantic import BaseModel, model_validator

 from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
 from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+logger = logging.getLogger(__name__)
+
+
 class ElasticSearchConfig(BaseModel):
     host: str
-    port: str
+    port: int
     username: str
     password: str
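With `port` now typed as `int`, pydantic's default (lax) validation still coerces numeric strings, so a port read from env-style config as "9200" keeps working. A minimal sketch with illustrative values:

    config = ElasticSearchConfig(
        host='http://localhost',   # illustrative values, not from the diff
        port='9200',               # lax-mode pydantic coerces the string to int 9200
        username='elastic',
        password='changeme',
    )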
@@ -37,12 +43,19 @@ class ElasticSearchVector(BaseVector):
     def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
         super().__init__(index_name.lower())
         self._client = self._init_client(config)
+        self._version = self._get_version()
+        self._check_version()
         self._attributes = attributes

     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
+            parsed_url = urlparse(config.host)
+            if parsed_url.scheme in ['http', 'https']:
+                hosts = f'{config.host}:{config.port}'
+            else:
+                hosts = f'http://{config.host}:{config.port}'
             client = Elasticsearch(
-                hosts=f'{config.host}:{config.port}',
+                hosts=hosts,
                 basic_auth=(config.username, config.password),
                 request_timeout=100000,
                 retry_on_timeout=True,
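The scheme check above preserves an explicit `http://` or `https://` prefix and defaults bare hostnames to `http://`. A small standalone sketch of how the two cases resolve (illustrative host values):

    from urllib.parse import urlparse

    for host in ('https://es.example.com', 'es.example.com'):
        scheme = urlparse(host).scheme  # '' for a bare hostname
        hosts = f'{host}:9200' if scheme in ('http', 'https') else f'http://{host}:9200'
        print(hosts)
    # https://es.example.com:9200
    # http://es.example.com:9200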
@@ -53,42 +66,27 @@ class ElasticSearchVector(BaseVector):
         return client

+    def _get_version(self) -> str:
+        info = self._client.info()
+        return info['version']['number']
+
+    def _check_version(self):
+        if self._version < '8.0.0':
+            raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
+
     def get_type(self) -> str:
         return 'elasticsearch'

     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
-        texts = [d.page_content for d in documents]
-        metadatas = [d.metadata for d in documents]
-
-        if not self._client.indices.exists(index=self._collection_name):
-            dim = len(embeddings[0])
-            mapping = {
-                "properties": {
-                    "text": {
-                        "type": "text"
-                    },
-                    "vector": {
-                        "type": "dense_vector",
-                        "index": True,
-                        "dims": dim,
-                        "similarity": "l2_norm"
-                    },
-                }
-            }
-            self._client.indices.create(index=self._collection_name, mappings=mapping)
-
-        added_ids = []
-        for i, text in enumerate(texts):
+        for i in range(len(documents)):
             self._client.index(index=self._collection_name,
                                id=uuids[i],
                                document={
-                                   "text": text,
-                                   "vector": embeddings[i] if embeddings[i] else None,
-                                   "metadata": metadatas[i] if metadatas[i] else {},
+                                   Field.CONTENT_KEY.value: documents[i].page_content,
+                                   Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
+                                   Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
                                })
-            added_ids.append(uuids[i])
-
         self._client.indices.refresh(index=self._collection_name)
         return uuids
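One caveat in `_check_version` above: `<` on version strings is lexicographic, so a hypothetical '10.1.0' would compare lower than '8.0.0' and be wrongly rejected. A safer comparison, sketched under the assumption of plain `major.minor.patch` version strings:

    def _check_version(self):
        # '10.1.0' < '8.0.0' is True as strings but False numerically
        parts = tuple(int(x) for x in self._version.split('.')[:3])
        if parts < (8, 0, 0):
            raise ValueError('Elasticsearch vector database version must be 8.0.0 or higher')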
@@ -116,28 +114,21 @@ class ElasticSearchVector(BaseVector):
         self._client.indices.delete(index=self._collection_name)

     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        query_str = {
-            "query": {
-                "script_score": {
-                    "query": {
-                        "match_all": {}
-                    },
-                    "script": {
-                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
-                        "params": {
-                            "query_vector": query_vector
-                        }
-                    }
-                }
-            }
-        }
+        top_k = kwargs.get("top_k", 10)
+        knn = {
+            "field": Field.VECTOR.value,
+            "query_vector": query_vector,
+            "k": top_k
+        }

-        results = self._client.search(index=self._collection_name, body=query_str)
+        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

         docs_and_scores = []
         for hit in results['hits']['hits']:
             docs_and_scores.append(
-                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
+                (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
+                          vector=hit['_source'][Field.VECTOR.value],
+                          metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))

         docs = []
         for doc, score in docs_and_scores:
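Replacing the script_score query with the top-level `knn` option moves retrieval onto Elasticsearch 8's kNN search against the `dense_vector` field; hits come back already ordered by score, which is why the manual re-sort disappears in the next hunk. Note that some 8.x releases also require `num_candidates` alongside `k`. Each hit's `_source` then carries exactly the three mapped fields consumed above; an illustrative shape, assuming the `Field` enum maps to 'page_content', 'vector', and 'metadata':

    hit = {
        '_score': 0.93,
        '_source': {
            'page_content': 'some chunk text',  # Field.CONTENT_KEY.value
            'vector': [0.1, 0.2],               # Field.VECTOR.value
            'metadata': {'doc_id': 'abc123'},   # Field.METADATA_KEY.value
        },
    }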
@@ -146,25 +137,61 @@ class ElasticSearchVector(BaseVector):
             doc.metadata['score'] = score
             docs.append(doc)

-        # Sort the documents by score in descending order
-        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
-
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {
             "match": {
-                "text": query
+                Field.CONTENT_KEY.value: query
             }
         }
         results = self._client.search(index=self._collection_name, query=query_str)
         docs = []
         for hit in results['hits']['hits']:
-            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
+            docs.append(Document(
+                page_content=hit['_source'][Field.CONTENT_KEY.value],
+                vector=hit['_source'][Field.VECTOR.value],
+                metadata=hit['_source'][Field.METADATA_KEY.value],
+            ))

         return docs

     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        return self.add_texts(texts, embeddings, **kwargs)
+        metadatas = [d.metadata for d in texts]
+        self.create_collection(embeddings, metadatas)
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(
+        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ):
+        lock_name = f'vector_indexing_lock_{self._collection_name}'
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name} already exists.")
+                return
+
+            if not self._client.indices.exists(index=self._collection_name):
+                dim = len(embeddings[0])
+                mappings = {
+                    "properties": {
+                        Field.CONTENT_KEY.value: {"type": "text"},
+                        Field.VECTOR.value: {  # Make sure the dimension is correct here
+                            "type": "dense_vector",
+                            "dims": dim,
+                            "similarity": "cosine"
+                        },
+                        Field.METADATA_KEY.value: {
+                            "type": "object",
+                            "properties": {
+                                "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
+                            }
+                        }
+                    }
+                }
+                self._client.indices.create(index=self._collection_name, mappings=mappings)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
 class ElasticSearchVectorFactory(AbstractVectorFactory):
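The new `create_collection` guards index creation with a short-lived Redis lock plus a one-hour marker key, so concurrent workers importing into the same collection create the index at most once. The pattern, condensed (assuming `redis_client` is a standard redis-py client, as the `extensions.ext_redis` import suggests):

    name = self._collection_name
    with redis_client.lock(f'vector_indexing_lock_{name}', timeout=20):  # serialize creators
        if redis_client.get(f'vector_indexing_{name}'):
            return                                    # another worker got here first
        # ... create the index if it does not exist ...
        redis_client.set(f'vector_indexing_{name}', 1, ex=3600)  # remember for an hour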