
Fix/qdrant data issue (#1203)

Co-authored-by: jyong <jyong@dify.ai>
tags/0.3.24
Jyong · 2 years ago
commit 724e053732

api/commands.py (+80 −79)

 import math
 import random
 import string
+import threading
 import time
 import uuid

 import click
 from tqdm import tqdm
-from flask import current_app
+from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
 from werkzeug.exceptions import NotFound


 @click.command('normalization-collections', help='restore all collections in one')
 def normalization_collections():
     click.echo(click.style('Start normalization collections.', fg='green'))
-    normalization_count = 0
+    normalization_count = []
     page = 1
     while True:
         try:
             datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
-                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
+                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break
+        datasets_result = datasets.items
         page += 1
-        for dataset in datasets:
-            if not dataset.collection_binding_id:
-                try:
-                    click.echo('restore dataset index: {}'.format(dataset.id))
-                    try:
-                        embedding_model = ModelFactory.get_embedding_model(
-                            tenant_id=dataset.tenant_id,
-                            model_provider_name=dataset.embedding_model_provider,
-                            model_name=dataset.embedding_model
-                        )
-                    except Exception:
-                        provider = Provider(
-                            id='provider_id',
-                            tenant_id=dataset.tenant_id,
-                            provider_name='openai',
-                            provider_type=ProviderType.CUSTOM.value,
-                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
-                            is_valid=True,
-                        )
-                        model_provider = OpenAIProvider(provider=provider)
-                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
-                                                          model_provider=model_provider)
-                    embeddings = CacheEmbedding(embedding_model)
-                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
-                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
-                               DatasetCollectionBinding.model_name == embedding_model.name). \
-                        order_by(DatasetCollectionBinding.created_at). \
-                        first()
-
-                    if not dataset_collection_binding:
-                        dataset_collection_binding = DatasetCollectionBinding(
-                            provider_name=embedding_model.model_provider.provider_name,
-                            model_name=embedding_model.name,
-                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
-                        )
-                        db.session.add(dataset_collection_binding)
-                        db.session.commit()
-
-                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
-
-                    index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if index:
-                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
-                    else:
-                        click.echo('passed.')
-
-                    original_index = QdrantVectorIndex(
-                        dataset=dataset,
-                        config=QdrantConfig(
-                            endpoint=current_app.config.get('QDRANT_URL'),
-                            api_key=current_app.config.get('QDRANT_API_KEY'),
-                            root_path=current_app.root_path
-                        ),
-                        embeddings=embeddings
-                    )
-                    if original_index:
-                        original_index.delete_original_collection(dataset, dataset_collection_binding)
-                        normalization_count += 1
-                    else:
-                        click.echo('passed.')
-                except Exception as e:
-                    click.echo(
-                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                    fg='red'))
-                    continue
-
-    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
+        for i in range(0, len(datasets_result), 5):
+            threads = []
+            sub_datasets = datasets_result[i:i + 5]
+            for dataset in sub_datasets:
+                document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
+                    'flask_app': current_app._get_current_object(),
+                    'dataset': dataset,
+                    'normalization_count': normalization_count
+                })
+                threads.append(document_format_thread)
+                document_format_thread.start()
+            for thread in threads:
+                thread.join()
+
+    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
+
+
+def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
+    with flask_app.app_context():
+        try:
+            click.echo('restore dataset index: {}'.format(dataset.id))
+            try:
+                embedding_model = ModelFactory.get_embedding_model(
+                    tenant_id=dataset.tenant_id,
+                    model_provider_name=dataset.embedding_model_provider,
+                    model_name=dataset.embedding_model
+                )
+            except Exception:
+                provider = Provider(
+                    id='provider_id',
+                    tenant_id=dataset.tenant_id,
+                    provider_name='openai',
+                    provider_type=ProviderType.CUSTOM.value,
+                    encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
+                    is_valid=True,
+                )
+                model_provider = OpenAIProvider(provider=provider)
+                embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
+                                                  model_provider=model_provider)
+            embeddings = CacheEmbedding(embedding_model)
+            dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
+                       DatasetCollectionBinding.model_name == embedding_model.name). \
+                order_by(DatasetCollectionBinding.created_at). \
+                first()
+
+            if not dataset_collection_binding:
+                dataset_collection_binding = DatasetCollectionBinding(
+                    provider_name=embedding_model.model_provider.provider_name,
+                    model_name=embedding_model.name,
+                    collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
+                )
+                db.session.add(dataset_collection_binding)
+                db.session.commit()
+
+            from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
+
+            index = QdrantVectorIndex(
+                dataset=dataset,
+                config=QdrantConfig(
+                    endpoint=current_app.config.get('QDRANT_URL'),
+                    api_key=current_app.config.get('QDRANT_API_KEY'),
+                    root_path=current_app.root_path
+                ),
+                embeddings=embeddings
+            )
+            if index:
+                # index.delete_by_group_id(dataset.id)
+                index.restore_dataset_in_one(dataset, dataset_collection_binding)
+            else:
+                click.echo('passed.')
+            normalization_count.append(1)
+        except Exception as e:
+            click.echo(
+                click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                            fg='red'))


 @click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
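
The rewritten command fans the restore work out to threads in batches of five, passing the concrete app object (current_app._get_current_object()) so each worker can re-enter a Flask application context. A minimal runnable sketch of that pattern, with the per-dataset work stubbed out (process_dataset and run_in_batches are illustrative names, not the commit's):

import threading

from flask import Flask, current_app


def process_dataset(flask_app: Flask, dataset, processed: list):
    # current_app is not available in a bare thread; re-entering the app
    # context restores access to config and the scoped db session.
    with flask_app.app_context():
        # ... per-dataset restore logic would go here ...
        processed.append(1)  # list.append is atomic under the GIL


def run_in_batches(datasets: list, batch_size: int = 5) -> int:
    processed: list = []
    app = current_app._get_current_object()  # unwrap the proxy once, up front
    for i in range(0, len(datasets), batch_size):
        threads = [
            threading.Thread(target=process_dataset, args=(app, d, processed))
            for d in datasets[i:i + batch_size]
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()  # drain the batch before starting the next one
    return len(processed)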

api/core/index/vector_index/base.py (+5 −3)

     def delete_by_group_id(self, group_id: str) -> None:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)

-        vector_store.delete()
+        if self.dataset.collection_binding_id:
+            vector_store.delete_by_group_id(group_id)
+        else:
+            vector_store.delete()

     def delete(self) -> None:
         vector_store = self._get_vector_store()


         if documents:
             try:
-                self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
+                self.add_texts(documents)
             except Exception as e:
                 raise e
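
Read as a whole, the post-change method dispatches on whether the dataset shares a collection; this is just the first hunk above consolidated, not new code:

    def delete_by_group_id(self, group_id: str) -> None:
        vector_store = self._get_vector_store()
        vector_store = cast(self._get_vector_store_class(), vector_store)

        if self.dataset.collection_binding_id:
            # Shared collection: remove only this dataset's points.
            vector_store.delete_by_group_id(group_id)
        else:
            # Dedicated collection: drop it outright.
            vector_store.delete()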



api/core/index/vector_index/qdrant.py (+67 −64)

             path=path,
             **kwargs,
         )
-        try:
-            # Skip any validation in case of forced collection recreate.
-            if force_recreate:
-                raise ValueError
-
-            # Get the vector configuration of the existing collection and vector, if it
-            # was specified. If the old configuration does not match the current one,
-            # an exception is being thrown.
-            collection_info = client.get_collection(collection_name=collection_name)
-            current_vector_config = collection_info.config.params.vectors
-            if isinstance(current_vector_config, dict) and vector_name is not None:
-                if vector_name not in current_vector_config:
-                    raise QdrantException(
-                        f"Existing Qdrant collection {collection_name} does not "
-                        f"contain vector named {vector_name}. Did you mean one of the "
-                        f"existing vectors: {', '.join(current_vector_config.keys())}? "
-                        f"If you want to recreate the collection, set `force_recreate` "
-                        f"parameter to `True`."
-                    )
-                current_vector_config = current_vector_config.get(
-                    vector_name
-                )  # type: ignore[assignment]
-            elif isinstance(current_vector_config, dict) and vector_name is None:
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} uses named vectors. "
-                    f"If you want to reuse it, please set `vector_name` to any of the "
-                    f"existing named vectors: "
-                    f"{', '.join(current_vector_config.keys())}."  # noqa
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-            elif (
-                not isinstance(current_vector_config, dict) and vector_name is not None
-            ):
-                raise QdrantException(
-                    f"Existing Qdrant collection {collection_name} doesn't use named "
-                    f"vectors. If you want to reuse it, please set `vector_name` to "
-                    f"`None`. If you want to recreate the collection, set "
-                    f"`force_recreate` parameter to `True`."
-                )
-
-            # Check if the vector configuration has the same dimensionality.
-            if current_vector_config.size != vector_size:  # type: ignore[union-attr]
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for vectors with "
-                    f"{current_vector_config.size} "  # type: ignore[union-attr]
-                    f"dimensions. Selected embeddings are {vector_size}-dimensional. "
-                    f"If you want to recreate the collection, set `force_recreate` "
-                    f"parameter to `True`."
-                )
-
-            current_distance_func = (
-                current_vector_config.distance.name.upper()  # type: ignore[union-attr]
-            )
-            if current_distance_func != distance_func:
-                raise QdrantException(
-                    f"Existing Qdrant collection is configured for "
-                    f"{current_vector_config.distance} "  # type: ignore[union-attr]
-                    f"similarity. Please set `distance_func` parameter to "
-                    f"`{distance_func}` if you want to reuse it. If you want to "
-                    f"recreate the collection, set `force_recreate` parameter to "
-                    f"`True`."
-                )
-        except (UnexpectedResponse, RpcError, ValueError):
+        all_collection_name = []
+        collections_response = client.get_collections()
+        collection_list = collections_response.collections
+        for collection in collection_list:
+            all_collection_name.append(collection.name)
+        if collection_name not in all_collection_name:
             vectors_config = rest.VectorParams(
                 size=vector_size,
                 distance=rest.Distance[distance_func],
                 timeout=timeout,  # type: ignore[arg-type]
             )
             is_new_collection = True
+        if force_recreate:
+            raise ValueError
+
+        # Get the vector configuration of the existing collection and vector, if it
+        # was specified. If the old configuration does not match the current one,
+        # an exception is being thrown.
+        collection_info = client.get_collection(collection_name=collection_name)
+        current_vector_config = collection_info.config.params.vectors
+        if isinstance(current_vector_config, dict) and vector_name is not None:
+            if vector_name not in current_vector_config:
+                raise QdrantException(
+                    f"Existing Qdrant collection {collection_name} does not "
+                    f"contain vector named {vector_name}. Did you mean one of the "
+                    f"existing vectors: {', '.join(current_vector_config.keys())}? "
+                    f"If you want to recreate the collection, set `force_recreate` "
+                    f"parameter to `True`."
+                )
+            current_vector_config = current_vector_config.get(
+                vector_name
+            )  # type: ignore[assignment]
+        elif isinstance(current_vector_config, dict) and vector_name is None:
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} uses named vectors. "
+                f"If you want to reuse it, please set `vector_name` to any of the "
+                f"existing named vectors: "
+                f"{', '.join(current_vector_config.keys())}."  # noqa
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+        elif (
+            not isinstance(current_vector_config, dict) and vector_name is not None
+        ):
+            raise QdrantException(
+                f"Existing Qdrant collection {collection_name} doesn't use named "
+                f"vectors. If you want to reuse it, please set `vector_name` to "
+                f"`None`. If you want to recreate the collection, set "
+                f"`force_recreate` parameter to `True`."
+            )
+
+        # Check if the vector configuration has the same dimensionality.
+        if current_vector_config.size != vector_size:  # type: ignore[union-attr]
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for vectors with "
+                f"{current_vector_config.size} "  # type: ignore[union-attr]
+                f"dimensions. Selected embeddings are {vector_size}-dimensional. "
+                f"If you want to recreate the collection, set `force_recreate` "
+                f"parameter to `True`."
+            )
+
+        current_distance_func = (
+            current_vector_config.distance.name.upper()  # type: ignore[union-attr]
+        )
+        if current_distance_func != distance_func:
+            raise QdrantException(
+                f"Existing Qdrant collection is configured for "
+                f"{current_vector_config.distance} "  # type: ignore[union-attr]
+                f"similarity. Please set `distance_func` parameter to "
+                f"`{distance_func}` if you want to reuse it. If you want to "
+                f"recreate the collection, set `force_recreate` parameter to "
+                f"`True`."
+            )
         qdrant = cls(
             client=client,
             collection_name=collection_name,
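
The change inverts the control flow: instead of validating inside a try block and creating the collection in the except handler, the new code lists collections up front, creates the collection only if the name is absent, and then validates unconditionally. A standalone sketch of that existence check with qdrant-client (the endpoint, collection name, and vector size are assumptions for illustration):

from qdrant_client import QdrantClient
from qdrant_client.http import models as rest

client = QdrantClient(url="http://localhost:6333")  # assumed local Qdrant

collection_name = "Vector_index_example_Node"  # hypothetical shared collection
existing = [c.name for c in client.get_collections().collections]
if collection_name not in existing:
    # Create only when absent, so a collection already shared by other
    # datasets is never recreated out from under them.
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=rest.VectorParams(
            size=1536,  # e.g. text-embedding-ada-002 output dimensionality
            distance=rest.Distance.COSINE,
        ),
    )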

api/core/index/vector_index/qdrant_vector_index.py (+13 −0)

                 ],
             ))

+    def delete(self) -> None:
+        vector_store = self._get_vector_store()
+        vector_store = cast(self._get_vector_store_class(), vector_store)
+
+        from qdrant_client.http import models
+        vector_store.del_texts(models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self.dataset.id),
+                ),
+            ],
+        ))
+
     def _is_origin(self):
         if self.dataset.index_struct_dict:
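
The new delete scopes removal with a payload filter on group_id, which is what lets many datasets share one physical collection. Roughly the same operation expressed directly against qdrant-client (the collection name and dataset id are hypothetical):

from qdrant_client import QdrantClient
from qdrant_client.http import models

client = QdrantClient(url="http://localhost:6333")  # assumed endpoint

# Delete only the points whose payload marks them as belonging to one
# dataset; other datasets sharing the collection are left untouched.
client.delete(
    collection_name="Vector_index_example_Node",  # hypothetical
    points_selector=models.FilterSelector(
        filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="group_id",
                    match=models.MatchValue(value="<dataset-id>"),  # hypothetical
                ),
            ],
        ),
    ),
)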

api/events/event_handlers/clean_when_dataset_deleted.py (+2 −1)

 @dataset_was_deleted.connect
 def handle(sender, **kwargs):
     dataset = sender
-    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
+    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
+                             dataset.index_struct, dataset.collection_binding_id)
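
The handler hangs off a blinker signal, so any code path that deletes a dataset only has to send the signal. A minimal sketch of both sides of that contract (the signal name and the send site are assumptions about the surrounding code, not part of this diff):

from blinker import signal

dataset_was_deleted = signal('dataset-was-deleted')  # assumed signal name


@dataset_was_deleted.connect
def handle(sender, **kwargs):
    dataset = sender  # the deleted Dataset row travels as the sender
    clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
                             dataset.index_struct, dataset.collection_binding_id)


# Emitter side, wherever the dataset row is actually deleted:
dataset_was_deleted.send(dataset)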

api/tasks/clean_dataset_task.py (+6 −4)





 @shared_task(queue='dataset')
-def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
+def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
+                       index_struct: str, collection_binding_id: str):
     """
     Clean dataset when dataset deleted.
     :param dataset_id: dataset id
     :param tenant_id: tenant id
     :param indexing_technique: indexing technique
     :param index_struct: index struct dict
+    :param collection_binding_id: collection binding id

     Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
     """

         id=dataset_id,
         tenant_id=tenant_id,
         indexing_technique=indexing_technique,
-        index_struct=index_struct
+        index_struct=index_struct,
+        collection_binding_id=collection_binding_id
     )

     documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
     segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()

     if dataset.indexing_technique == 'high_quality':
         vector_index = IndexBuilder.get_default_high_quality_index(dataset)
         try:
-            vector_index.delete()
+            vector_index.delete_by_group_id(dataset.id)
         except Exception:
             logging.exception("Delete doc index failed when dataset deleted.")
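
Every enqueue site must now pass five positional arguments, as the updated handler in clean_when_dataset_deleted.py does. A sketch of the call shape (the dataset values are hypothetical):

# .delay() serializes the arguments and enqueues the task on the 'dataset' queue.
clean_dataset_task.delay(
    dataset.id,
    dataset.tenant_id,
    dataset.indexing_technique,
    dataset.index_struct,
    dataset.collection_binding_id,  # new: identifies the shared collection
)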



api/tasks/deal_dataset_vector_index_task.py (+2 −2)

         raise Exception('Dataset not found')

     if action == "remove":
-        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
-        index.delete()
+        index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
+        index.delete_by_group_id(dataset.id)
     elif action == "add":
         dataset_documents = db.session.query(DatasetDocument).filter(
             DatasetDocument.dataset_id == dataset_id,
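
The remove branch now forces index resolution with ignore_high_quality_check=True before deleting by group id; presumably the check would otherwise yield no index for a dataset whose indexing technique has already changed, leaving orphaned vectors behind. Consolidated, the post-change branch reads:

if action == "remove":
    # Resolve the index even when the high-quality check would fail,
    # then remove only this dataset's points from the collection.
    index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
    index.delete_by_group_id(dataset.id)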
