
feat: rewrite Elasticsearch index and search code to support vector and full-text search (#7641)

Co-authored-by: haokai <haokai@shuwen.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Bowen Liang <bowenliang@apache.org>
Co-authored-by: wellCh4n <wellCh4n@foxmail.com>
tags/0.7.3
Kenn, 1 year ago
Commit 122ce41020

api/configs/middleware/__init__.py (+2, -0)

 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
+from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
 from configs.middleware.vdb.milvus_config import MilvusConfig
 from configs.middleware.vdb.myscale_config import MyScaleConfig
 from configs.middleware.vdb.opensearch_config import OpenSearchConfig

     TencentVectorDBConfig,
     TiDBVectorConfig,
     WeaviateConfig,
+    ElasticsearchConfig,
 ):
     pass

api/configs/middleware/vdb/elasticsearch_config.py (+30, -0)

from typing import Optional

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class ElasticsearchConfig(BaseSettings):
"""
Elasticsearch configs
"""

ELASTICSEARCH_HOST: Optional[str] = Field(
description="Elasticsearch host",
default="127.0.0.1",
)

ELASTICSEARCH_PORT: PositiveInt = Field(
description="Elasticsearch port",
default=9200,
)

ELASTICSEARCH_USERNAME: Optional[str] = Field(
description="Elasticsearch username",
default="elastic",
)

ELASTICSEARCH_PASSWORD: Optional[str] = Field(
description="Elasticsearch password",
default="elastic",
)
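Since ElasticsearchConfig extends pydantic's BaseSettings, each field is populated from the environment variable of the same name. A minimal sketch of wiring it up, assuming the module path shown in the diff (host value is illustrative):

import os

from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig

# Illustrative values; in a real deployment these come from the process
# environment or a .env file read by pydantic-settings.
os.environ["ELASTICSEARCH_HOST"] = "http://es.internal.example.com"
os.environ["ELASTICSEARCH_PORT"] = "9200"

config = ElasticsearchConfig()
print(config.ELASTICSEARCH_HOST, config.ELASTICSEARCH_PORT)  # port parsed to int 9200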

api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py (+79, -52)

 import json
-from typing import Any
+import logging
+from typing import Any, Optional
+from urllib.parse import urlparse

 import requests
 from elasticsearch import Elasticsearch
 from pydantic import BaseModel, model_validator

 from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.field import Field
 from core.rag.datasource.vdb.vector_base import BaseVector
 from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
 from models.dataset import Dataset

+logger = logging.getLogger(__name__)



 class ElasticSearchConfig(BaseModel):
     host: str
-    port: str
+    port: int
     username: str
     password: str


     def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: list):
         super().__init__(index_name.lower())
         self._client = self._init_client(config)
+        self._version = self._get_version()
+        self._check_version()
         self._attributes = attributes


     def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch:
         try:
+            parsed_url = urlparse(config.host)
+            if parsed_url.scheme in ['http', 'https']:
+                hosts = f'{config.host}:{config.port}'
+            else:
+                hosts = f'http://{config.host}:{config.port}'
             client = Elasticsearch(
-                hosts=f'{config.host}:{config.port}',
+                hosts=hosts,
                 basic_auth=(config.username, config.password),
                 request_timeout=100000,
                 retry_on_timeout=True,
             )

         return client

+    def _get_version(self) -> str:
+        info = self._client.info()
+        return info['version']['number']
+
+    def _check_version(self):
+        if self._version < '8.0.0':
+            raise ValueError("Elasticsearch vector database version must be greater than 8.0.0")
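The urlparse branch in _init_client means ELASTICSEARCH_HOST may now be either a bare hostname/IP or a full URL. A quick illustration of the scheme check (hosts are made up); note also that _check_version compares version strings lexicographically, which distinguishes 7.x from 8.x but would misorder e.g. '10.0.0' against '8.0.0':

from urllib.parse import urlparse

print(urlparse('https://es.example.com').scheme)  # 'https' -> host used as-is
print(urlparse('192.168.1.10').scheme)            # ''      -> 'http://' is prepended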

     def get_type(self) -> str:
         return 'elasticsearch'


     def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
         uuids = self._get_uuids(documents)
-        texts = [d.page_content for d in documents]
-        metadatas = [d.metadata for d in documents]
-        if not self._client.indices.exists(index=self._collection_name):
-            dim = len(embeddings[0])
-            mapping = {
-                "properties": {
-                    "text": {
-                        "type": "text"
-                    },
-                    "vector": {
-                        "type": "dense_vector",
-                        "index": True,
-                        "dims": dim,
-                        "similarity": "l2_norm"
-                    },
-                }
-            }
-            self._client.indices.create(index=self._collection_name, mappings=mapping)
-        added_ids = []
-        for i, text in enumerate(texts):
+        for i in range(len(documents)):
             self._client.index(index=self._collection_name,
                                id=uuids[i],
                                document={
-                                   "text": text,
-                                   "vector": embeddings[i] if embeddings[i] else None,
-                                   "metadata": metadatas[i] if metadatas[i] else {},
+                                   Field.CONTENT_KEY.value: documents[i].page_content,
+                                   Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
+                                   Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {}
                                })
-            added_ids.append(uuids[i])

         self._client.indices.refresh(index=self._collection_name)
         return uuids
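For reference, the body of each indexed document now looks roughly like this. The Field enum values are assumptions inferred from how the search methods below read them back (the enum itself is not shown in this diff), and the sample values are invented:

# Assuming Field.CONTENT_KEY.value == "page_content", Field.VECTOR.value == "vector",
# and Field.METADATA_KEY.value == "metadata" (hypothetical; not shown in the diff).
doc_body = {
    "page_content": "Dify now speaks Elasticsearch.",  # Field.CONTENT_KEY.value
    "vector": [0.12, -0.03, 0.88],                     # Field.VECTOR.value
    "metadata": {"doc_id": "a1b2c3"},                  # Field.METADATA_KEY.value
}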


         self._client.indices.delete(index=self._collection_name)


     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
-        query_str = {
-            "query": {
-                "script_score": {
-                    "query": {
-                        "match_all": {}
-                    },
-                    "script": {
-                        "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
-                        "params": {
-                            "query_vector": query_vector
-                        }
-                    }
-                }
-            }
-        }
+        top_k = kwargs.get("top_k", 10)
+        knn = {
+            "field": Field.VECTOR.value,
+            "query_vector": query_vector,
+            "k": top_k
+        }

-        results = self._client.search(index=self._collection_name, body=query_str)
+        results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

         docs_and_scores = []
         for hit in results['hits']['hits']:
             docs_and_scores.append(
-                (Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']), hit['_score']))
+                (Document(page_content=hit['_source'][Field.CONTENT_KEY.value],
+                          vector=hit['_source'][Field.VECTOR.value],
+                          metadata=hit['_source'][Field.METADATA_KEY.value]), hit['_score']))

         docs = []
         for doc, score in docs_and_scores:
             doc.metadata['score'] = score
             docs.append(doc)

+        # Sort the documents by score in descending order
+        docs = sorted(docs, key=lambda x: x.metadata['score'], reverse=True)
+
         return docs
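The switch from a script_score query to the client's knn parameter moves scoring from a scripted brute-force pass over every document to Elasticsearch 8's approximate kNN search. A standalone sketch of the same call against a raw client, where host, credentials, index name, and vector are illustrative:

from elasticsearch import Elasticsearch

# Same approximate kNN search as search_by_vector, outside the class.
es = Elasticsearch('http://127.0.0.1:9200', basic_auth=('elastic', 'elastic'))
resp = es.search(
    index='my_dataset_index',
    knn={'field': 'vector', 'query_vector': [0.1, 0.2, 0.3], 'k': 10},
    size=10,
)
for hit in resp['hits']['hits']:
    print(hit['_score'], hit['_source'].get('metadata'))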

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         query_str = {
             "match": {
-                "text": query
+                Field.CONTENT_KEY.value: query
             }
         }
         results = self._client.search(index=self._collection_name, query=query_str)
         docs = []
         for hit in results['hits']['hits']:
-            docs.append(Document(page_content=hit['_source']['text'], metadata=hit['_source']['metadata']))
+            docs.append(Document(
+                page_content=hit['_source'][Field.CONTENT_KEY.value],
+                vector=hit['_source'][Field.VECTOR.value],
+                metadata=hit['_source'][Field.METADATA_KEY.value],
+            ))

         return docs


     def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
-        return self.add_texts(texts, embeddings, **kwargs)
+        metadatas = [d.metadata for d in texts]
+        self.create_collection(embeddings, metadatas)
+        self.add_texts(texts, embeddings, **kwargs)
+
+    def create_collection(
+        self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
+    ):
+        lock_name = f'vector_indexing_lock_{self._collection_name}'
+        with redis_client.lock(lock_name, timeout=20):
+            collection_exist_cache_key = f'vector_indexing_{self._collection_name}'
+            if redis_client.get(collection_exist_cache_key):
+                logger.info(f"Collection {self._collection_name} already exists.")
+                return
+
+            if not self._client.indices.exists(index=self._collection_name):
+                dim = len(embeddings[0])
+                mappings = {
+                    "properties": {
+                        Field.CONTENT_KEY.value: {"type": "text"},
+                        Field.VECTOR.value: {  # Make sure the dimension is correct here
+                            "type": "dense_vector",
+                            "dims": dim,
+                            "similarity": "cosine"
+                        },
+                        Field.METADATA_KEY.value: {
+                            "type": "object",
+                            "properties": {
+                                "doc_id": {"type": "keyword"}  # Map doc_id to keyword type
+                            }
+                        }
+                    }
+                }
+                self._client.indices.create(index=self._collection_name, mappings=mappings)
+
+            redis_client.set(collection_exist_cache_key, 1, ex=3600)
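create_collection guards index creation with a Redis lock plus a one-hour cache key, so concurrent workers do not race to create the same index and repeat callers skip the existence check entirely. The pattern in isolation, with connection details and function names invented for the sketch:

import redis

redis_client = redis.Redis(host='localhost', port=6379)

def ensure_index(name: str, create_fn):
    # Serialize creation across workers, then cache the fact that the
    # index exists so later calls return immediately for an hour.
    with redis_client.lock(f'vector_indexing_lock_{name}', timeout=20):
        cache_key = f'vector_indexing_{name}'
        if redis_client.get(cache_key):
            return  # another worker already created it recently
        create_fn(name)
        redis_client.set(cache_key, 1, ex=3600)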




 class ElasticSearchVectorFactory(AbstractVectorFactory):
