import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:
    """Vector index operations for a dataset: add and delete nodes or documents,
    reuse cached embeddings, and retry vector-store calls on read timeouts."""

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        # create and persist the index struct if the dataset doesn't have one yet
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if the node hash already has a cached embedding, reuse it
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # attach the freshly computed embeddings and cache them for reuse
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    # a record with this hash already exists; skip it
                    db.session.rollback()
                    continue
                except Exception:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        # build a new list instead of removing items from `nodes` while iterating
        # over it, which would skip elements
        return [node for node in nodes if not index.exists_by_node_id(node.doc_id)]
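
# Usage sketch (illustrative only, not part of the original module). It assumes a
# persisted Dataset row and a list of llama_index Nodes prepared elsewhere;
# `dataset_id`, `document_id`, and `nodes` are placeholder names.
#
#   dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
#   vector_index = VectorIndex(dataset)
#   vector_index.add_nodes(nodes, duplicate_check=True)  # embed, cache, and insert
#   vector_index.del_doc(document_id)                     # remove every node of a document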
 
 