|
|
|
@@ -1,10 +1,9 @@ |
|
|
|
import json |
|
|
|
import logging |
|
|
|
from typing import Any, Optional |
|
|
|
from uuid import uuid4 |
|
|
|
|
|
|
|
from pydantic import BaseModel, model_validator |
|
|
|
from pymilvus import MilvusClient, MilvusException, connections |
|
|
|
from pymilvus import MilvusClient, MilvusException |
|
|
|
from pymilvus.milvus_client import IndexParams |
|
|
|
|
|
|
|
from configs import dify_config |
|
|
|
@@ -21,20 +20,17 @@ logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class MilvusConfig(BaseModel): |
|
|
|
host: str |
|
|
|
port: int |
|
|
|
uri: str |
|
|
|
token: Optional[str] = None |
|
|
|
user: str |
|
|
|
password: str |
|
|
|
secure: bool = False |
|
|
|
batch_size: int = 100 |
|
|
|
database: str = "default" |
|
|
|
|
|
|
|
@model_validator(mode='before') |
|
|
|
def validate_config(cls, values: dict) -> dict: |
|
|
|
if not values.get('host'): |
|
|
|
raise ValueError("config MILVUS_HOST is required") |
|
|
|
if not values.get('port'): |
|
|
|
raise ValueError("config MILVUS_PORT is required") |
|
|
|
if not values.get('uri'): |
|
|
|
raise ValueError("config MILVUS_URI is required") |
|
|
|
if not values.get('user'): |
|
|
|
raise ValueError("config MILVUS_USER is required") |
|
|
|
if not values.get('password'): |
|
|
|
@@ -43,11 +39,10 @@ class MilvusConfig(BaseModel): |
|
|
|
|
|
|
|
def to_milvus_params(self):
    """Return this configuration as a plain dict of connection parameters.

    Each entry copies one attribute of the config under the key listed
    below. All keys match the attribute name except ``'db_name'``, which
    carries the value of ``self.database``.

    :return: dict with the keys ``host``, ``port``, ``uri``, ``token``,
        ``user``, ``password``, ``secure`` and ``db_name``, in that order.
    """
    # (output key, attribute on self) pairs; order fixes the dict's key order.
    key_to_attribute = (
        ('host', 'host'),
        ('port', 'port'),
        ('uri', 'uri'),
        ('token', 'token'),
        ('user', 'user'),
        ('password', 'password'),
        ('secure', 'secure'),
        ('db_name', 'database'),
    )
    return {key: getattr(self, attribute) for key, attribute in key_to_attribute}
|
|
|
|
|
|
|
@@ -111,32 +106,14 @@ class MilvusVector(BaseVector): |
|
|
|
return None |
|
|
|
|
|
|
|
def delete_by_metadata_field(self, key: str, value: str):
    """Delete the entities selected by a metadata key/value pair.

    The primary keys to remove are whatever
    ``self.get_ids_by_metadata_field(key, value)`` returns; they are passed
    to ``self._client.delete`` as ``pks``. Nothing happens when the
    collection does not exist or when no ids are returned.

    The existence check uses the client this instance already holds. The
    earlier form also opened a second connection under a fresh random alias
    on every call, only to ask the same question, and never released it.

    :param key: metadata field name, forwarded to ``get_ids_by_metadata_field``.
    :param value: metadata value, forwarded to ``get_ids_by_metadata_field``.
    """
    if not self._client.has_collection(self._collection_name):
        # No collection, so there is nothing to look up or delete.
        return

    ids = self.get_ids_by_metadata_field(key, value)
    if ids:
        self._client.delete(collection_name=self._collection_name, pks=ids)
|
|
|
|
|
|
|
def delete_by_ids(self, ids: list[str]) -> None: |
|
|
|
alias = uuid4().hex |
|
|
|
if self._client_config.secure: |
|
|
|
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
else: |
|
|
|
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, |
|
|
|
db_name=self._client_config.database) |
|
|
|
|
|
|
|
from pymilvus import utility |
|
|
|
if utility.has_collection(self._collection_name, using=alias): |
|
|
|
if self._client.has_collection(self._collection_name): |
|
|
|
|
|
|
|
result = self._client.query(collection_name=self._collection_name, |
|
|
|
filter=f'metadata["doc_id"] in {ids}', |
|
|
|
@@ -146,29 +123,11 @@ class MilvusVector(BaseVector): |
|
|
|
self._client.delete(collection_name=self._collection_name, pks=ids) |
|
|
|
|
|
|
|
def delete(self) -> None:
    """Drop this instance's collection if it exists.

    Both the existence check and the drop go through ``self._client``.
    The earlier form first opened an extra connection under a fresh random
    alias on every call, which was never released, and issued the
    check-and-drop a second time through it.
    """
    if self._client.has_collection(self._collection_name):
        # Second positional argument is passed as None, as in the original call.
        self._client.drop_collection(self._collection_name, None)
|
|
|
|
|
|
|
def text_exists(self, id: str) -> bool: |
|
|
|
alias = uuid4().hex |
|
|
|
if self._client_config.secure: |
|
|
|
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
else: |
|
|
|
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password, |
|
|
|
db_name=self._client_config.database) |
|
|
|
|
|
|
|
from pymilvus import utility |
|
|
|
if not utility.has_collection(self._collection_name, using=alias): |
|
|
|
if not self._client.has_collection(self._collection_name): |
|
|
|
return False |
|
|
|
|
|
|
|
result = self._client.query(collection_name=self._collection_name, |
|
|
|
@@ -210,15 +169,7 @@ class MilvusVector(BaseVector): |
|
|
|
if redis_client.get(collection_exist_cache_key): |
|
|
|
return |
|
|
|
# Grab the existing collection if it exists |
|
|
|
from pymilvus import utility |
|
|
|
alias = uuid4().hex |
|
|
|
if self._client_config.secure: |
|
|
|
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
else: |
|
|
|
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
connections.connect(alias=alias, uri=uri, user=self._client_config.user, |
|
|
|
password=self._client_config.password, db_name=self._client_config.database) |
|
|
|
if not utility.has_collection(self._collection_name, using=alias): |
|
|
|
if not self._client.has_collection(self._collection_name): |
|
|
|
from pymilvus import CollectionSchema, DataType, FieldSchema |
|
|
|
from pymilvus.orm.types import infer_dtype_bydata |
|
|
|
|
|
|
|
@@ -263,11 +214,7 @@ class MilvusVector(BaseVector): |
|
|
|
redis_client.set(collection_exist_cache_key, 1, ex=3600) |
|
|
|
|
|
|
|
def _init_client(self, config) -> MilvusClient:
    """Create and return the ``MilvusClient`` described by ``config``.

    The client is built from ``config.uri``, ``config.user``,
    ``config.password`` and ``config.database`` (passed as ``db_name``).

    Only one client is constructed. The earlier form first built a client
    from a URI assembled out of ``config.host``/``config.port`` and then
    immediately replaced it with the ``config.uri`` one, leaving the first
    client unused.

    :param config: object exposing ``uri``, ``user``, ``password`` and
        ``database`` attributes.
    :return: the newly created client.
    """
    return MilvusClient(
        uri=config.uri,
        user=config.user,
        password=config.password,
        db_name=config.database,
    )
|
|
|
|
|
|
|
|
|
|
|
@@ -285,11 +232,10 @@ class MilvusVectorFactory(AbstractVectorFactory): |
|
|
|
return MilvusVector( |
|
|
|
collection_name=collection_name, |
|
|
|
config=MilvusConfig( |
|
|
|
host=dify_config.MILVUS_HOST, |
|
|
|
port=dify_config.MILVUS_PORT, |
|
|
|
uri=dify_config.MILVUS_URI, |
|
|
|
token=dify_config.MILVUS_TOKEN, |
|
|
|
user=dify_config.MILVUS_USER, |
|
|
|
password=dify_config.MILVUS_PASSWORD, |
|
|
|
secure=dify_config.MILVUS_SECURE, |
|
|
|
database=dify_config.MILVUS_DATABASE, |
|
|
|
) |
|
|
|
) |