
optimize lindorm vdb add_texts (#17212)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
tags/1.2.0
Jiang 7 months ago
parent
commit
ff388fe3e6
1 changed file with 68 additions and 24 deletions
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py

@@ -1,10 +1,12 @@
 import copy
 import json
 import logging
+import time
 from typing import Any, Optional
 
 from opensearchpy import OpenSearch
 from pydantic import BaseModel, model_validator
+from tenacity import retry, stop_after_attempt, wait_exponential
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
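The two new imports drive the logic added in the next hunk: tenacity wraps the bulk call with bounded retries and exponential backoff, and time provides the pause between batches. Below is a minimal, hedged sketch of how that retry policy behaves; the flaky function and counter are invented for illustration and are not part of this change.

# Hedged sketch: same policy as the diff, stop_after_attempt(3) with
# wait_exponential(multiplier=1, min=4, max=10). Illustrative names only.
from tenacity import retry, stop_after_attempt, wait_exponential

calls = {"n": 0}

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def flaky_bulk():
    calls["n"] += 1
    if calls["n"] < 3:
        # retried after a backoff of 4-10 seconds (exponential, clamped to [min, max])
        raise ConnectionError("transient failure")
    return "ok"

print(flaky_bulk())  # succeeds on the third attempt; if all 3 attempts failed, tenacity would raise RetryError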
@@ -77,31 +79,74 @@ class LindormVectorStore(BaseVector):
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)
 
-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        actions = []
+    def add_texts(
+        self,
+        documents: list[Document],
+        embeddings: list[list[float]],
+        batch_size: int = 64,
+        timeout: int = 60,
+        **kwargs,
+    ):
+        logger.info(f"Total documents to add: {len(documents)}")
         uuids = self._get_uuids(documents)
-        for i in range(len(documents)):
-            action_header = {
-                "index": {
-                    "_index": self.collection_name.lower(),
-                    "_id": uuids[i],
+
+        total_docs = len(documents)
+        num_batches = (total_docs + batch_size - 1) // batch_size
+
+        @retry(
+            stop=stop_after_attempt(3),
+            wait=wait_exponential(multiplier=1, min=4, max=10),
+        )
+        def _bulk_with_retry(actions):
+            try:
+                response = self._client.bulk(actions, timeout=timeout)
+                if response["errors"]:
+                    error_items = [item for item in response["items"] if "error" in item["index"]]
+                    error_msg = f"Bulk indexing had {len(error_items)} errors"
+                    logger.exception(error_msg)
+                    raise Exception(error_msg)
+                return response
+            except Exception:
+                logger.exception("Bulk indexing error")
+                raise
+
+        for batch_num in range(num_batches):
+            start_idx = batch_num * batch_size
+            end_idx = min((batch_num + 1) * batch_size, total_docs)
+
+            actions = []
+            for i in range(start_idx, end_idx):
+                action_header = {
+                    "index": {
+                        "_index": self.collection_name.lower(),
+                        "_id": uuids[i],
+                    }
                 }
-            }
-            action_values: dict[str, Any] = {
-                Field.CONTENT_KEY.value: documents[i].page_content,
-                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                Field.METADATA_KEY.value: documents[i].metadata,
-            }
-            if self._using_ugc:
-                action_header["index"]["routing"] = self._routing
-                if self._routing_field is not None:
-                    action_values[self._routing_field] = self._routing
-            actions.append(action_header)
-            actions.append(action_values)
-        response = self._client.bulk(actions)
-        if response["errors"]:
-            for item in response["items"]:
-                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+                action_values: dict[str, Any] = {
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i],
+                    Field.METADATA_KEY.value: documents[i].metadata,
+                }
+                if self._using_ugc:
+                    action_header["index"]["routing"] = self._routing
+                    if self._routing_field is not None:
+                        action_values[self._routing_field] = self._routing
+
+                actions.append(action_header)
+                actions.append(action_values)
+
+            logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
+
+            try:
+                _bulk_with_retry(actions)
+                logger.info(f"Successfully processed batch {batch_num + 1}")
+                # simple latency to avoid too many requests in a short time
+                if batch_num < num_batches - 1:
+                    time.sleep(1)
+
+            except Exception:
+                logger.exception(f"Failed to process batch {batch_num + 1}")
+                raise
 
     def get_ids_by_metadata_field(self, key: str, value: str):
         query: dict[str, Any] = {
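For reference, the hunk above splits the documents with a ceiling division and builds the bulk payload as alternating action-header/document-body pairs. A standalone, hedged sketch of just that structure follows; the index name, field names, and documents are made up and only the layout mirrors the diff.

# Hedged sketch of the batch math and bulk-action layout used by the new add_texts.
docs = [{"id": f"doc-{i}", "text": f"chunk {i}"} for i in range(150)]
batch_size = 64

# Ceiling division: 150 documents with batch_size 64 -> 3 batches (64, 64, 22).
num_batches = (len(docs) + batch_size - 1) // batch_size

for batch_num in range(num_batches):
    start = batch_num * batch_size
    end = min((batch_num + 1) * batch_size, len(docs))
    actions = []
    for doc in docs[start:end]:
        actions.append({"index": {"_index": "my_collection", "_id": doc["id"]}})  # action header
        actions.append({"page_content": doc["text"]})                             # document body
    print(f"batch {batch_num + 1}/{num_batches}: {len(actions) // 2} documents")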
@@ -130,7 +175,6 @@ class LindormVectorStore(BaseVector):
             if self._using_ugc:
                 params["routing"] = self._routing
             self._client.delete(index=self._collection_name, id=id, params=params)
-            self.refresh()
         else:
             logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
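The removed line in this last hunk drops the explicit refresh after each single-document delete. Callers that need the deletion to be visible to searches immediately can still trigger a refresh themselves. A hedged illustration with the OpenSearch-compatible client used by this store; host, index name, document id, and routing value are placeholders, not values from this repository.

# Hedged sketch: explicit refresh after a delete, mirroring the refresh() helper shown above.
from opensearchpy import OpenSearch

client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])  # placeholder endpoint
client.delete(index="my_collection", id="doc-1", params={"routing": "tenant-a"})
client.indices.refresh(index="my_collection")  # make the deletion visible to searches right away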

