Co-authored-by: jyong <jyong@dify.ai>tags/0.5.11-fix1
| @@ -8,6 +8,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK | |||
| from core.rag.datasource.keyword.keyword_base import BaseKeyword | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment | |||
| @@ -121,26 +122,28 @@ class Jieba(BaseKeyword): | |||
| db.session.commit() | |||
| def _get_dataset_keyword_table(self) -> Optional[dict]: | |||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||
| if dataset_keyword_table: | |||
| if dataset_keyword_table.keyword_table_dict: | |||
| return dataset_keyword_table.keyword_table_dict['__data__']['table'] | |||
| else: | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self.dataset.id, | |||
| keyword_table=json.dumps({ | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": {} | |||
| } | |||
| }, cls=SetEncoder) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| db.session.commit() | |||
| lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id) | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| dataset_keyword_table = self.dataset.dataset_keyword_table | |||
| if dataset_keyword_table: | |||
| if dataset_keyword_table.keyword_table_dict: | |||
| return dataset_keyword_table.keyword_table_dict['__data__']['table'] | |||
| else: | |||
| dataset_keyword_table = DatasetKeywordTable( | |||
| dataset_id=self.dataset.id, | |||
| keyword_table=json.dumps({ | |||
| '__type__': 'keyword_table', | |||
| '__data__': { | |||
| "index_id": self.dataset.id, | |||
| "summary": None, | |||
| "table": {} | |||
| } | |||
| }, cls=SetEncoder) | |||
| ) | |||
| db.session.add(dataset_keyword_table) | |||
| db.session.commit() | |||
| return {} | |||
| return {} | |||
| def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]) -> dict: | |||
| for keyword in keywords: | |||
| @@ -8,6 +8,7 @@ from pymilvus import MilvusClient, MilvusException, connections | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| logger = logging.getLogger(__name__) | |||
| @@ -61,17 +62,7 @@ class MilvusVector(BaseVector): | |||
| 'params': {"M": 8, "efConstruction": 64} | |||
| } | |||
| metadatas = [d.metadata for d in texts] | |||
| # Grab the existing collection if it exists | |||
| from pymilvus import utility | |||
| alias = uuid4().hex | |||
| if self._client_config.secure: | |||
| uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||
| else: | |||
| uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||
| connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) | |||
| if not utility.has_collection(self._collection_name, using=alias): | |||
| self.create_collection(embeddings, metadatas, index_params) | |||
| self.create_collection(embeddings, metadatas, index_params) | |||
| self.add_texts(texts, embeddings) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| @@ -187,46 +178,60 @@ class MilvusVector(BaseVector): | |||
| def create_collection( | |||
| self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None | |||
| ) -> str: | |||
| from pymilvus import CollectionSchema, DataType, FieldSchema | |||
| from pymilvus.orm.types import infer_dtype_bydata | |||
| # Determine embedding dim | |||
| dim = len(embeddings[0]) | |||
| fields = [] | |||
| if metadatas: | |||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||
| # Create the text field | |||
| fields.append( | |||
| FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) | |||
| ) | |||
| # Create the primary key field | |||
| fields.append( | |||
| FieldSchema( | |||
| Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True | |||
| ) | |||
| ) | |||
| # Create the vector field, supports binary or float vectors | |||
| fields.append( | |||
| FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) | |||
| ) | |||
| # Create the schema for the collection | |||
| schema = CollectionSchema(fields) | |||
| for x in schema.fields: | |||
| self._fields.append(x.name) | |||
| # Since primary field is auto-id, no need to track it | |||
| self._fields.remove(Field.PRIMARY_KEY.value) | |||
| # Create the collection | |||
| collection_name = self._collection_name | |||
| self._client.create_collection_with_schema(collection_name=collection_name, | |||
| schema=schema, index_param=index_params, | |||
| consistency_level=self._consistency_level) | |||
| return collection_name | |||
| ): | |||
| lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| # Grab the existing collection if it exists | |||
| from pymilvus import utility | |||
| alias = uuid4().hex | |||
| if self._client_config.secure: | |||
| uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||
| else: | |||
| uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) | |||
| connections.connect(alias=alias, uri=uri, user=self._client_config.user, | |||
| password=self._client_config.password) | |||
| if not utility.has_collection(self._collection_name, using=alias): | |||
| from pymilvus import CollectionSchema, DataType, FieldSchema | |||
| from pymilvus.orm.types import infer_dtype_bydata | |||
| # Determine embedding dim | |||
| dim = len(embeddings[0]) | |||
| fields = [] | |||
| if metadatas: | |||
| fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535)) | |||
| # Create the text field | |||
| fields.append( | |||
| FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535) | |||
| ) | |||
| # Create the primary key field | |||
| fields.append( | |||
| FieldSchema( | |||
| Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True | |||
| ) | |||
| ) | |||
| # Create the vector field, supports binary or float vectors | |||
| fields.append( | |||
| FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim) | |||
| ) | |||
| # Create the schema for the collection | |||
| schema = CollectionSchema(fields) | |||
| for x in schema.fields: | |||
| self._fields.append(x.name) | |||
| # Since primary field is auto-id, no need to track it | |||
| self._fields.remove(Field.PRIMARY_KEY.value) | |||
| # Create the collection | |||
| collection_name = self._collection_name | |||
| self._client.create_collection_with_schema(collection_name=collection_name, | |||
| schema=schema, index_param=index_params, | |||
| consistency_level=self._consistency_level) | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def _init_client(self, config) -> MilvusClient: | |||
| if config.secure: | |||
| uri = "https://" + str(config.host) + ":" + str(config.port) | |||
| @@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| if TYPE_CHECKING: | |||
| from qdrant_client import grpc # noqa | |||
| @@ -77,6 +78,17 @@ class QdrantVector(BaseVector): | |||
| vector_size = len(embeddings[0]) | |||
| # get collection name | |||
| collection_name = self._collection_name | |||
| # create collection | |||
| self.create_collection(collection_name, vector_size) | |||
| self.add_texts(texts, embeddings, **kwargs) | |||
| def create_collection(self, collection_name: str, vector_size: int): | |||
| lock_name = 'vector_indexing_lock_{}'.format(collection_name) | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| collection_name = collection_name or uuid.uuid4().hex | |||
| all_collection_name = [] | |||
| collections_response = self._client.get_collections() | |||
| @@ -84,40 +96,35 @@ class QdrantVector(BaseVector): | |||
| for collection in collection_list: | |||
| all_collection_name.append(collection.name) | |||
| if collection_name not in all_collection_name: | |||
| # create collection | |||
| self.create_collection(collection_name, vector_size) | |||
| self.add_texts(texts, embeddings, **kwargs) | |||
| def create_collection(self, collection_name: str, vector_size: int): | |||
| from qdrant_client.http import models as rest | |||
| vectors_config = rest.VectorParams( | |||
| size=vector_size, | |||
| distance=rest.Distance[self._distance_func], | |||
| ) | |||
| hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||
| max_indexing_threads=0, on_disk=False) | |||
| self._client.recreate_collection( | |||
| collection_name=collection_name, | |||
| vectors_config=vectors_config, | |||
| hnsw_config=hnsw_config, | |||
| timeout=int(self._client_config.timeout), | |||
| ) | |||
| from qdrant_client.http import models as rest | |||
| vectors_config = rest.VectorParams( | |||
| size=vector_size, | |||
| distance=rest.Distance[self._distance_func], | |||
| ) | |||
| hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000, | |||
| max_indexing_threads=0, on_disk=False) | |||
| self._client.recreate_collection( | |||
| collection_name=collection_name, | |||
| vectors_config=vectors_config, | |||
| hnsw_config=hnsw_config, | |||
| timeout=int(self._client_config.timeout), | |||
| ) | |||
| # create payload index | |||
| self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, | |||
| field_schema=PayloadSchemaType.KEYWORD, | |||
| field_type=PayloadSchemaType.KEYWORD) | |||
| # creat full text index | |||
| text_index_params = TextIndexParams( | |||
| type=TextIndexType.TEXT, | |||
| tokenizer=TokenizerType.MULTILINGUAL, | |||
| min_token_len=2, | |||
| max_token_len=20, | |||
| lowercase=True | |||
| ) | |||
| self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, | |||
| field_schema=text_index_params) | |||
| # create payload index | |||
| self._client.create_payload_index(collection_name, Field.GROUP_KEY.value, | |||
| field_schema=PayloadSchemaType.KEYWORD, | |||
| field_type=PayloadSchemaType.KEYWORD) | |||
| # creat full text index | |||
| text_index_params = TextIndexParams( | |||
| type=TextIndexType.TEXT, | |||
| tokenizer=TokenizerType.MULTILINGUAL, | |||
| min_token_len=2, | |||
| max_token_len=20, | |||
| lowercase=True | |||
| ) | |||
| self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value, | |||
| field_schema=text_index_params) | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| uuids = self._get_uuids(documents) | |||
| @@ -8,6 +8,7 @@ from pydantic import BaseModel, root_validator | |||
| from core.rag.datasource.vdb.field import Field | |||
| from core.rag.datasource.vdb.vector_base import BaseVector | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| @@ -79,16 +80,23 @@ class WeaviateVector(BaseVector): | |||
| } | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| schema = self._default_schema(self._collection_name) | |||
| # check whether the index already exists | |||
| if not self._client.schema.contains(schema): | |||
| # create collection | |||
| self._client.schema.create_class(schema) | |||
| # create collection | |||
| self._create_collection() | |||
| # create vector | |||
| self.add_texts(texts, embeddings) | |||
| def _create_collection(self): | |||
| lock_name = 'vector_indexing_lock_{}'.format(self._collection_name) | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name) | |||
| if redis_client.get(collection_exist_cache_key): | |||
| return | |||
| schema = self._default_schema(self._collection_name) | |||
| if not self._client.schema.contains(schema): | |||
| # create collection | |||
| self._client.schema.create_class(schema) | |||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | |||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | |||
| uuids = self._get_uuids(documents) | |||
| texts = [d.page_content for d in documents] | |||