|
|
|
@@ -20,16 +20,17 @@ class MilvusConfig(BaseModel): |
|
|
|
password: str |
|
|
|
secure: bool = False |
|
|
|
batch_size: int = 100 |
|
|
|
database: str = "default" |
|
|
|
|
|
|
|
@root_validator() |
|
|
|
def validate_config(cls, values: dict) -> dict: |
|
|
|
if not values['host']: |
|
|
|
if not values.get('host'): |
|
|
|
raise ValueError("config MILVUS_HOST is required") |
|
|
|
if not values['port']: |
|
|
|
if not values.get('port'): |
|
|
|
raise ValueError("config MILVUS_PORT is required") |
|
|
|
if not values['user']: |
|
|
|
if not values.get('user'): |
|
|
|
raise ValueError("config MILVUS_USER is required") |
|
|
|
if not values['password']: |
|
|
|
if not values.get('password'): |
|
|
|
raise ValueError("config MILVUS_PASSWORD is required") |
|
|
|
return values |
|
|
|
|
|
|
|
@@ -39,7 +40,8 @@ class MilvusConfig(BaseModel): |
|
|
|
'port': self.port, |
|
|
|
'user': self.user, |
|
|
|
'password': self.password, |
|
|
|
'secure': self.secure |
|
|
|
'secure': self.secure, |
|
|
|
'db_name': self.database, |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
@@ -192,7 +194,7 @@ class MilvusVector(BaseVector): |
|
|
|
else: |
|
|
|
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) |
|
|
|
connections.connect(alias=alias, uri=uri, user=self._client_config.user, |
|
|
|
password=self._client_config.password) |
|
|
|
password=self._client_config.password, db_name=self._client_config.database) |
|
|
|
if not utility.has_collection(self._collection_name, using=alias): |
|
|
|
from pymilvus import CollectionSchema, DataType, FieldSchema |
|
|
|
from pymilvus.orm.types import infer_dtype_bydata |