
optimize lindorm vdb add_texts (#17212)

Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
tags/1.2.0
Jiang 7 months ago
commit ff388fe3e6
1 changed file with 68 additions and 24 deletions

api/core/rag/datasource/vdb/lindorm/lindorm_vector.py  (+68 -24)

 import copy
 import json
 import logging
+import time
 from typing import Any, Optional

 from opensearchpy import OpenSearch
 from pydantic import BaseModel, model_validator
+from tenacity import retry, stop_after_attempt, wait_exponential

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
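
The two new imports (time and tenacity) drive the pacing and retry behavior added below. As a minimal standalone sketch of the same retry policy (flaky_call is hypothetical, standing in for the bulk request): at most 3 attempts, with exponential waits clamped between 4 and 10 seconds.

    from tenacity import retry, stop_after_attempt, wait_exponential

    state = {"attempts": 0}

    @retry(
        stop=stop_after_attempt(3),  # give up after the third failed attempt
        wait=wait_exponential(multiplier=1, min=4, max=10),  # exponential backoff, clamped to 4-10s
    )
    def flaky_call():
        # Hypothetical stand-in for self._client.bulk(...): fails twice, then succeeds.
        state["attempts"] += 1
        if state["attempts"] < 3:
            raise ConnectionError("transient failure")
        return {"errors": False}

    print(flaky_call())  # raises tenacity.RetryError only if all 3 attempts fail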
     def refresh(self):
         self._client.indices.refresh(index=self._collection_name)


-    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
-        actions = []
+    def add_texts(
+        self,
+        documents: list[Document],
+        embeddings: list[list[float]],
+        batch_size: int = 64,
+        timeout: int = 60,
+        **kwargs,
+    ):
+        logger.info(f"Total documents to add: {len(documents)}")
         uuids = self._get_uuids(documents)
-        for i in range(len(documents)):
-            action_header = {
-                "index": {
-                    "_index": self.collection_name.lower(),
-                    "_id": uuids[i],
-                }
-            }
-            action_values: dict[str, Any] = {
-                Field.CONTENT_KEY.value: documents[i].page_content,
-                Field.VECTOR.value: embeddings[i],  # Make sure you pass an array here
-                Field.METADATA_KEY.value: documents[i].metadata,
-            }
-            if self._using_ugc:
-                action_header["index"]["routing"] = self._routing
-                if self._routing_field is not None:
-                    action_values[self._routing_field] = self._routing
-            actions.append(action_header)
-            actions.append(action_values)
-        response = self._client.bulk(actions)
-        if response["errors"]:
-            for item in response["items"]:
-                print(f"{item['index']['status']}: {item['index']['error']['type']}")
+        total_docs = len(documents)
+        num_batches = (total_docs + batch_size - 1) // batch_size
+
+        @retry(
+            stop=stop_after_attempt(3),
+            wait=wait_exponential(multiplier=1, min=4, max=10),
+        )
+        def _bulk_with_retry(actions):
+            try:
+                response = self._client.bulk(actions, timeout=timeout)
+                if response["errors"]:
+                    error_items = [item for item in response["items"] if "error" in item["index"]]
+                    error_msg = f"Bulk indexing had {len(error_items)} errors"
+                    logger.exception(error_msg)
+                    raise Exception(error_msg)
+                return response
+            except Exception:
+                logger.exception("Bulk indexing error")
+                raise
+
+        for batch_num in range(num_batches):
+            start_idx = batch_num * batch_size
+            end_idx = min((batch_num + 1) * batch_size, total_docs)
+
+            actions = []
+            for i in range(start_idx, end_idx):
+                action_header = {
+                    "index": {
+                        "_index": self.collection_name.lower(),
+                        "_id": uuids[i],
+                    }
+                }
+                action_values: dict[str, Any] = {
+                    Field.CONTENT_KEY.value: documents[i].page_content,
+                    Field.VECTOR.value: embeddings[i],
+                    Field.METADATA_KEY.value: documents[i].metadata,
+                }
+                if self._using_ugc:
+                    action_header["index"]["routing"] = self._routing
+                    if self._routing_field is not None:
+                        action_values[self._routing_field] = self._routing
+
+                actions.append(action_header)
+                actions.append(action_values)
+
+            logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
+
+            try:
+                _bulk_with_retry(actions)
+                logger.info(f"Successfully processed batch {batch_num + 1}")
+                # simple latency to avoid too many requests in a short time
+                if batch_num < num_batches - 1:
+                    time.sleep(1)
+
+            except Exception:
+                logger.exception(f"Failed to process batch {batch_num + 1}")
+                raise
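
For reference, the batch count above is plain ceiling division, and each document contributes two entries to the flat bulk body: an action header followed by the document source. A small self-contained sketch with made-up numbers (index name and field keys are placeholders, not the real Field enum values):

    # Ceiling division: 150 documents at batch_size=64 -> 3 batches (64, 64, 22).
    total_docs, batch_size = 150, 64
    num_batches = (total_docs + batch_size - 1) // batch_size
    assert num_batches == 3

    # Each document becomes a header/source pair in the actions list,
    # so a batch of N documents produces 2 * N entries.
    actions = []
    for i in range(3):  # toy batch of three documents
        actions.append({"index": {"_index": "my_collection", "_id": f"uuid-{i}"}})
        actions.append({"page_content": f"chunk {i}", "vector": [0.1, 0.2], "metadata": {}})
    assert len(actions) == 2 * 3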

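A hedged usage sketch of the new signature (Document here is a minimal stand-in for the RAG layer's document class, and the call itself is commented out because it needs a configured LindormVector instance):

    from dataclasses import dataclass, field
    from typing import Any

    @dataclass
    class Document:  # minimal stand-in for core.rag.models.document.Document
        page_content: str
        metadata: dict[str, Any] = field(default_factory=dict)

    docs = [Document(page_content=f"chunk {i}", metadata={"doc_id": str(i)}) for i in range(200)]
    embeddings = [[0.0] * 8 for _ in docs]  # toy 8-dimensional vectors

    # With batch_size=64, 200 documents are written in 4 bulk requests,
    # with a 1-second pause between consecutive batches:
    # vector.add_texts(docs, embeddings, batch_size=64, timeout=60)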

     def get_ids_by_metadata_field(self, key: str, value: str):
         query: dict[str, Any] = {

             if self._using_ugc:
                 params["routing"] = self._routing
             self._client.delete(index=self._collection_name, id=id, params=params)
+            self.refresh()
         else:
             logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
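
One behavioral note on the delete path: the added refresh() forces the index to make the removal visible to searches immediately, at the cost of an extra request. A minimal sketch of the same pattern with a bare opensearch-py client (host and names are placeholders):

    from opensearchpy import OpenSearch

    client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])  # placeholder endpoint

    if client.exists(index="my_collection", id="doc-1"):
        client.delete(index="my_collection", id="doc-1")
        # Without an explicit refresh, the deleted document can still match
        # searches until the next automatic refresh interval.
        client.indices.refresh(index="my_collection")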


