Co-authored-by: jyong <jyong@dify.ai>tags/0.5.11-fix1
| from core.rag.datasource.keyword.keyword_base import BaseKeyword | from core.rag.datasource.keyword.keyword_base import BaseKeyword | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | ||||
| db.session.commit() | db.session.commit() | ||||
| def _get_dataset_keyword_table(self) -> Optional[dict]: | def _get_dataset_keyword_table(self) -> Optional[dict]: | ||||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||||
| if dataset_keyword_table: | |||||
| if dataset_keyword_table.keyword_table_dict: | |||||
| return dataset_keyword_table.keyword_table_dict['__data__']['table'] | |||||
| else: | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self.dataset.id, | |||||
| keyword_table=json.dumps({ | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": {} | |||||
| } | |||||
| }, cls=SetEncoder) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) | |||||
| with redis_client.lock(lock_name, timeout=20): | |||||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||||
| if dataset_keyword_table: | |||||
| if dataset_keyword_table.keyword_table_dict: | |||||
| return dataset_keyword_table.keyword_table_dict['__data__']['table'] | |||||
| else: | |||||
| dataset_keyword_table = DatasetKeywordTable( | |||||
| dataset_id=self.dataset.id, | |||||
| keyword_table=json.dumps({ | |||||
| '__type__': 'keyword_table', | |||||
| '__data__': { | |||||
| "index_id": self.dataset.id, | |||||
| "summary": None, | |||||
| "table": {} | |||||
| } | |||||
| }, cls=SetEncoder) | |||||
| ) | |||||
| db.session.add(dataset_keyword_table) | |||||
| db.session.commit() | |||||
| return {} | |||||
| return {} | |||||
| def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: | def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: | ||||
| for keyword in keywords: | for keyword in keywords: |
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | |||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||
| 'params': {"M": 8, "efConstruction": 64} | 'params': {"M": 8, "efConstruction": 64} | ||||
| } | } | ||||
| metadatas = [d.metadata for d in texts] | metadatas = [d.metadata for d in texts] | ||||
| # Grab the existing collection if it exists | |||||
| from pymilvus import utility | |||||
| alias = uuid4().hex | |||||
| if self._client_config.secure: | |||||
| uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||||
| else: | |||||
| uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||||
| connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) | |||||
| if not utility.has_collection(self._collection_name, using=alias): | |||||
| self.create_collection(embeddings, metadatas, index_params) | |||||
| self.create_collection(embeddings, metadatas, index_params) | |||||
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| def create_collection( | def create_collection( | ||||
| self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | ||||
| ) -> str: | |||||
| from pymilvus import CollectionSchema, DataType, FieldSchema | |||||
| from pymilvus.orm.types import infer_dtype_bydata | |||||
| # Determine embedding dim | |||||
| dim = len(embeddings[0]) | |||||
| fields = [] | |||||
| if metadatas: | |||||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||||
| # Create the text field | |||||
| fields.append( | |||||
| FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) | |||||
| ) | |||||
| # Create the primary key field | |||||
| fields.append( | |||||
| FieldSchema( | |||||
| Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True | |||||
| ) | |||||
| ) | |||||
| # Create the vector field, supports binary or float vectors | |||||
| fields.append( | |||||
| FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) | |||||
| ) | |||||
| # Create the schema for the collection | |||||
| schema = CollectionSchema(fields) | |||||
| for x in schema.fields: | |||||
| self._fields.append(x.name) | |||||
| # Since primary field is auto-id, no need to track it | |||||
| self._fields.remove(Field.PRIMARY_KEY.value) | |||||
| # Create the collection | |||||
| collection_name = self._collection_name | |||||
| self._client.create_collection_with_schema(collection_name=collection_name, | |||||
| schema=schema, index_param=index_params, | |||||
| consistency_level=self._consistency_level) | |||||
| return collection_name | |||||
| ): | |||||
| lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) | |||||
| with redis_client.lock(lock_name, timeout=20): | |||||
| collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) | |||||
| if redis_client.get(collection_exist_cache_key): | |||||
| return | |||||
| # Grab the existing collection if it exists | |||||
| from pymilvus import utility | |||||
| alias = uuid4().hex | |||||
| if self._client_config.secure: | |||||
| uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||||
| else: | |||||
| uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||||
| connections.connect(alias=alias, uri=uri, user=self._client_config.user, | |||||
| password=self._client_config.password) | |||||
| if not utility.has_collection(self._collection_name, using=alias): | |||||
| from pymilvus import CollectionSchema, DataType, FieldSchema | |||||
| from pymilvus.orm.types import infer_dtype_bydata | |||||
| # Determine embedding dim | |||||
| dim = len(embeddings[0]) | |||||
| fields = [] | |||||
| if metadatas: | |||||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||||
| # Create the text field | |||||
| fields.append( | |||||
| FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) | |||||
| ) | |||||
| # Create the primary key field | |||||
| fields.append( | |||||
| FieldSchema( | |||||
| Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True | |||||
| ) | |||||
| ) | |||||
| # Create the vector field, supports binary or float vectors | |||||
| fields.append( | |||||
| FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) | |||||
| ) | |||||
| # Create the schema for the collection | |||||
| schema = CollectionSchema(fields) | |||||
| for x in schema.fields: | |||||
| self._fields.append(x.name) | |||||
| # Since primary field is auto-id, no need to track it | |||||
| self._fields.remove(Field.PRIMARY_KEY.value) | |||||
| # Create the collection | |||||
| collection_name = self._collection_name | |||||
| self._client.create_collection_with_schema(collection_name=collection_name, | |||||
| schema=schema, index_param=index_params, | |||||
| consistency_level=self._consistency_level) | |||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||||
| def _init_client(self, config) -> MilvusClient: | def _init_client(self, config) -> MilvusClient: | ||||
| if config.secure: | if config.secure: | ||||
| uri = "https://" + str(config.host) + ":" + str(config.port) | uri = "https://" + str(config.host) + ":" + str(config.port) |
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | |||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
| from qdrant_client import grpc # noqa | from qdrant_client import grpc # noqa | ||||
| vector_size = len(embeddings[0]) | vector_size = len(embeddings[0]) | ||||
| # get collection name | # get collection name | ||||
| collection_name = self._collection_name | collection_name = self._collection_name | ||||
| # create collection | |||||
| self.create_collection(collection_name, vector_size) | |||||
| self.add_texts(texts, embeddings, **kwargs) | |||||
| def create_collection(self, collection_name: str, vector_size: int): | |||||
| lock_name = 'vector_indexing_lock_{}'.format(collection_name) | |||||
| with redis_client.lock(lock_name, timeout=20): | |||||
| collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) | |||||
| if redis_client.get(collection_exist_cache_key): | |||||
| return | |||||
| collection_name = collection_name or uuid.uuid4().hex | collection_name = collection_name or uuid.uuid4().hex | ||||
| all_collection_name = [] | all_collection_name = [] | ||||
| collections_response = self._client.get_collections() | collections_response = self._client.get_collections() | ||||
| for collection in collection_list: | for collection in collection_list: | ||||
| all_collection_name.append(collection.name) | all_collection_name.append(collection.name) | ||||
| if collection_name not in all_collection_name: | if collection_name not in all_collection_name: | ||||
| # create collection | |||||
| self.create_collection(collection_name, vector_size) | |||||
| self.add_texts(texts, embeddings, **kwargs) | |||||
| def create_collection(self, collection_name: str, vector_size: int): | |||||
| from qdrant_client.http import models as rest | |||||
| vectors_config = rest.VectorParams( | |||||
| size=vector_size, | |||||
| distance=rest.Distance[self._distance_func], | |||||
| ) | |||||
| hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||||
| max_indexing_threads=0, on_disk=False) | |||||
| self._client.recreate_collection( | |||||
| collection_name=collection_name, | |||||
| vectors_config=vectors_config, | |||||
| hnsw_config=hnsw_config, | |||||
| timeout=int(self._client_config.timeout), | |||||
| ) | |||||
| from qdrant_client.http import models as rest | |||||
| vectors_config = rest.VectorParams( | |||||
| size=vector_size, | |||||
| distance=rest.Distance[self._distance_func], | |||||
| ) | |||||
| hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||||
| max_indexing_threads=0, on_disk=False) | |||||
| self._client.recreate_collection( | |||||
| collection_name=collection_name, | |||||
| vectors_config=vectors_config, | |||||
| hnsw_config=hnsw_config, | |||||
| timeout=int(self._client_config.timeout), | |||||
| ) | |||||
| # create payload index | |||||
| self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, | |||||
| field_schema=PayloadSchemaType.KEYWORD, | |||||
| field_type=PayloadSchemaType.KEYWORD) | |||||
| # creat full text index | |||||
| text_index_params = TextIndexParams( | |||||
| type=TextIndexType.TEXT, | |||||
| tokenizer=TokenizerType.MULTILINGUAL, | |||||
| min_token_len=2, | |||||
| max_token_len=20, | |||||
| lowercase=True | |||||
| ) | |||||
| self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, | |||||
| field_schema=text_index_params) | |||||
| # create payload index | |||||
| self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, | |||||
| field_schema=PayloadSchemaType.KEYWORD, | |||||
| field_type=PayloadSchemaType.KEYWORD) | |||||
| # creat full text index | |||||
| text_index_params = TextIndexParams( | |||||
| type=TextIndexType.TEXT, | |||||
| tokenizer=TokenizerType.MULTILINGUAL, | |||||
| min_token_len=2, | |||||
| max_token_len=20, | |||||
| lowercase=True | |||||
| ) | |||||
| self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, | |||||
| field_schema=text_index_params) | |||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| uuids = self._get_uuids(documents) | uuids = self._get_uuids(documents) |
| from core.rag.datasource.vdb.field import Field | from core.rag.datasource.vdb.field import Field | ||||
| from core.rag.datasource.vdb.vector_base import BaseVector | from core.rag.datasource.vdb.vector_base import BaseVector | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset | from models.dataset import Dataset | ||||
| } | } | ||||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| schema = self._default_schema(self._collection_name) | |||||
| # check whether the index already exists | |||||
| if not self._client.schema.contains(schema): | |||||
| # create collection | |||||
| self._client.schema.create_class(schema) | |||||
| # create collection | |||||
| self._create_collection() | |||||
| # create vector | # create vector | ||||
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def _create_collection(self): | |||||
| lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) | |||||
| with redis_client.lock(lock_name, timeout=20): | |||||
| collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) | |||||
| if redis_client.get(collection_exist_cache_key): | |||||
| return | |||||
| schema = self._default_schema(self._collection_name) | |||||
| if not self._client.schema.contains(schema): | |||||
| # create collection | |||||
| self._client.schema.create_class(schema) | |||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| uuids = self._get_uuids(documents) | uuids = self._get_uuids(documents) | ||||
| texts = [d.page_content for d in documents] | texts = [d.page_content for d in documents] |