|
|
|
@@ -1,10 +1,12 @@ |
|
|
|
import copy |
|
|
|
import json |
|
|
|
import logging |
|
|
|
import time |
|
|
|
from typing import Any, Optional |
|
|
|
|
|
|
|
from opensearchpy import OpenSearch |
|
|
|
from pydantic import BaseModel, model_validator |
|
|
|
from tenacity import retry, stop_after_attempt, wait_exponential |
|
|
|
|
|
|
|
from configs import dify_config |
|
|
|
from core.rag.datasource.vdb.field import Field |
|
|
|
@@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector): |
|
|
|
def refresh(self):
    """Force an index refresh so recently written documents become searchable.

    Delegates to the OpenSearch-compatible client's indices.refresh API for
    this store's collection index. Blocks until the refresh request returns.
    """
    self._client.indices.refresh(index=self._collection_name)
|
|
|
|
|
|
|
def add_texts(
    self,
    documents: list[Document],
    embeddings: list[list[float]],
    batch_size: int = 64,
    timeout: int = 60,
    **kwargs,
):
    """Index documents and their embedding vectors in batches via the bulk API.

    Documents are written in batches of ``batch_size`` using the client's
    ``bulk`` endpoint; each batch is retried up to 3 times with exponential
    backoff on failure. A short pause is inserted between batches to avoid
    flooding the server with requests.

    Args:
        documents: Documents to index; ``page_content`` and ``metadata`` are
            stored per document.
        embeddings: One embedding vector per document, aligned by index with
            ``documents``.
        batch_size: Number of documents per bulk request (default 64).
        timeout: Per-request timeout in seconds passed to the bulk call.
        **kwargs: Accepted for interface compatibility; unused here.

    Raises:
        Exception: If a bulk request still fails (or reports item-level
            errors) after all retry attempts.
    """
    logger.info(f"Total documents to add: {len(documents)}")
    uuids = self._get_uuids(documents)

    total_docs = len(documents)
    # Ceiling division: the last batch may be smaller than batch_size.
    num_batches = (total_docs + batch_size - 1) // batch_size

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    def _bulk_with_retry(actions):
        """Send one bulk request; raise (triggering a retry) on any error."""
        try:
            response = self._client.bulk(actions, timeout=timeout)
            if response["errors"]:
                # The bulk API returns HTTP 200 with per-item errors; surface
                # them as an exception so tenacity retries the whole batch.
                error_items = [item for item in response["items"] if "error" in item["index"]]
                error_msg = f"Bulk indexing had {len(error_items)} errors"
                logger.exception(error_msg)
                raise Exception(error_msg)
            return response
        except Exception:
            logger.exception("Bulk indexing error")
            raise

    for batch_num in range(num_batches):
        start_idx = batch_num * batch_size
        end_idx = min((batch_num + 1) * batch_size, total_docs)

        # Bulk body alternates action-header / document-source lines.
        actions = []
        for i in range(start_idx, end_idx):
            action_header = {
                "index": {
                    "_index": self.collection_name.lower(),
                    "_id": uuids[i],
                }
            }
            action_values: dict[str, Any] = {
                Field.CONTENT_KEY.value: documents[i].page_content,
                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
                Field.METADATA_KEY.value: documents[i].metadata,
            }
            if self._using_ugc:
                # UGC mode shards by routing value; mirror it into the
                # document body when a routing field is configured.
                action_header["index"]["routing"] = self._routing
                if self._routing_field is not None:
                    action_values[self._routing_field] = self._routing

            actions.append(action_header)
            actions.append(action_values)

        logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")

        try:
            _bulk_with_retry(actions)
            logger.info(f"Successfully processed batch {batch_num + 1}")
            # simple latency to avoid too many requests in a short time
            if batch_num < num_batches - 1:
                time.sleep(1)
        except Exception:
            logger.exception(f"Failed to process batch {batch_num + 1}")
            raise
|
|
|
|
|
|
|
def get_ids_by_metadata_field(self, key: str, value: str): |
|
|
|
query: dict[str, Any] = { |
|
|
|
@@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector): |
|
|
|
if self._using_ugc: |
|
|
|
params["routing"] = self._routing |
|
|
|
self._client.delete(index=self._collection_name, id=id, params=params) |
|
|
|
self.refresh() |
|
|
|
else: |
|
|
|
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") |
|
|
|
|