
Feat/improve vector database logic (#1193)

Co-authored-by: jyong <jyong@dify.ai>
tags/0.3.23
Jyong, 2 years ago
Commit 269a465fc4 (no account is linked to the committer's email)

api/commands.py (+124, -17)

import random
import string
import time
import uuid

import click
from tqdm import tqdm
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetQuery, Document, DatasetCollectionBinding
from models.model import Account, AppModelConfig, App
import secrets
import base64
kw_index = IndexBuilder.get_index(dataset, 'economy')
# delete from vector index
if vector_index:
    if dataset.collection_binding_id:
        vector_index.delete_by_group_id(dataset.id)
    else:
        vector_index.delete()
kw_index.delete()
# update document
update_params = {
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                  model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)


from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
index.create_qdrant_dataset(dataset)
index_struct = {
    "type": 'qdrant',
    "vector_store": {
        "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
}
dataset.index_struct = json.dumps(index_struct)
db.session.commit()
click.echo('passed.')
except Exception as e:
    click.echo(
        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                    fg='red'))
    continue


click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))
is_valid=True,
)
model_provider = OpenAIProvider(provider=provider)
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                  model_provider=model_provider)
embeddings = CacheEmbedding(embedding_model)


from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
click.echo('passed.')
except Exception as e:
    click.echo(
        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                    fg='red'))
    continue


click.echo(click.style('Congratulations! Update {} dataset indexes.'.format(create_count), fg='green'))



@click.command('normalization-collections', help='restore all collections in one')
def normalization_collections():
    click.echo(click.style('Start normalization collections.', fg='green'))
    normalization_count = 0

    page = 1
    while True:
        try:
            datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
                .order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
        except NotFound:
            break

        page += 1
        for dataset in datasets:
            if not dataset.collection_binding_id:
                try:
                    click.echo('restore dataset index: {}'.format(dataset.id))
                    try:
                        embedding_model = ModelFactory.get_embedding_model(
                            tenant_id=dataset.tenant_id,
                            model_provider_name=dataset.embedding_model_provider,
                            model_name=dataset.embedding_model
                        )
                    except Exception:
                        provider = Provider(
                            id='provider_id',
                            tenant_id=dataset.tenant_id,
                            provider_name='openai',
                            provider_type=ProviderType.CUSTOM.value,
                            encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
                            is_valid=True,
                        )
                        model_provider = OpenAIProvider(provider=provider)
                        embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
                                                          model_provider=model_provider)
                    embeddings = CacheEmbedding(embedding_model)
                    dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
                        filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
                               DatasetCollectionBinding.model_name == embedding_model.name). \
                        order_by(DatasetCollectionBinding.created_at). \
                        first()

                    if not dataset_collection_binding:
                        dataset_collection_binding = DatasetCollectionBinding(
                            provider_name=embedding_model.model_provider.provider_name,
                            model_name=embedding_model.name,
                            collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
                        )
                        db.session.add(dataset_collection_binding)
                        db.session.commit()

                    from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig

                    index = QdrantVectorIndex(
                        dataset=dataset,
                        config=QdrantConfig(
                            endpoint=current_app.config.get('QDRANT_URL'),
                            api_key=current_app.config.get('QDRANT_API_KEY'),
                            root_path=current_app.root_path
                        ),
                        embeddings=embeddings
                    )
                    if index:
                        index.restore_dataset_in_one(dataset, dataset_collection_binding)
                    else:
                        click.echo('passed.')

                    original_index = QdrantVectorIndex(
                        dataset=dataset,
                        config=QdrantConfig(
                            endpoint=current_app.config.get('QDRANT_URL'),
                            api_key=current_app.config.get('QDRANT_API_KEY'),
                            root_path=current_app.root_path
                        ),
                        embeddings=embeddings
                    )
                    if original_index:
                        original_index.delete_original_collection(dataset, dataset_collection_binding)
                        normalization_count += 1
                    else:
                        click.echo('passed.')
                except Exception as e:
                    click.echo(
                        click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
                                    fg='red'))
                    continue

    click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
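Note: these are standard Click commands registered on the Flask application (see register_commands below), so the one-off migration would typically be run through the Flask CLI of the api service, e.g. `flask normalization-collections` (hypothetical invocation; the exact entry point depends on how the API container is started).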


@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
def update_app_model_configs(batch_size):
        .join(App, App.app_model_config_id == AppModelConfig.id) \
        .filter(App.mode == 'completion') \
        .count()
    if total_records == 0:
        click.secho("No data to migrate.", fg='green')
        return
    offset = i * batch_size
    limit = min(batch_size, total_records - offset)


click.secho(f"Fetching batch {i+1}/{num_batches} from source database...", fg='green')
click.secho(f"Fetching batch {i + 1}/{num_batches} from source database...", fg='green')
data_batch = db.session.query(AppModelConfig) \ data_batch = db.session.query(AppModelConfig) \
.join(App, App.app_model_config_id == AppModelConfig.id) \ .join(App, App.app_model_config_id == AppModelConfig.id) \
.filter(App.mode == 'completion') \ .filter(App.mode == 'completion') \
.order_by(App.created_at) \ .order_by(App.created_at) \
.offset(offset).limit(limit).all() .offset(offset).limit(limit).all()
if not data_batch: if not data_batch:
click.secho("No more data to migrate.", fg='green') click.secho("No more data to migrate.", fg='green')
break break
app_data = db.session.query(App) \ app_data = db.session.query(App) \
.filter(App.id == data.app_id) \ .filter(App.id == data.app_id) \
.one() .one()
account_data = db.session.query(Account) \ account_data = db.session.query(Account) \
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \ .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) \
.filter(TenantAccountJoin.role == 'owner') \ .filter(TenantAccountJoin.role == 'owner') \
db.session.commit() db.session.commit()


except Exception as e: except Exception as e:
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}", fg='red')
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
fg='red')
continue continue
click.secho(f"Successfully migrated batch {i+1}/{num_batches}.", fg='green')
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
pbar.update(len(data_batch)) pbar.update(len(data_batch))



def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
    app.cli.add_command(clean_unused_dataset_indexes)
    app.cli.add_command(create_qdrant_indexes)
    app.cli.add_command(update_qdrant_indexes)
    app.cli.add_command(update_app_model_configs)
    app.cli.add_command(normalization_collections)

api/core/index/base.py (+8, -0)

def create(self, texts: list[Document], **kwargs) -> BaseIndex:
    raise NotImplementedError

@abstractmethod
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    raise NotImplementedError

@abstractmethod
def add_texts(self, texts: list[Document], **kwargs):
    raise NotImplementedError

def delete_by_ids(self, ids: list[str]) -> None:
    raise NotImplementedError

@abstractmethod
def delete_by_group_id(self, group_id: str) -> None:
    raise NotImplementedError

@abstractmethod
def delete_by_document_id(self, document_id: str):
    raise NotImplementedError
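With these two additions, every concrete index must now provide a collection-scoped create and a group-scoped delete. A minimal sketch of a conforming subclass (hypothetical toy class for illustration only, with the other abstract methods omitted; the real implementations follow in the keyword table, Weaviate, Milvus and Qdrant files below):

class InMemoryIndex(BaseIndex):
    """Hypothetical in-memory index, used only to illustrate the new contract."""
    _store: dict[str, list[Document]] = {}

    def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
        # write documents into an explicitly named (shared) collection
        # rather than the per-dataset default collection
        self._store.setdefault(collection_name, []).extend(texts)
        return self

    def delete_by_group_id(self, group_id: str) -> None:
        # remove only the entries that belong to this group (dataset) id
        for docs in self._store.values():
            docs[:] = [d for d in docs if d.metadata.get('dataset_id') != group_id]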

api/core/index/keyword_table_index/keyword_table_index.py (+32, -0)



return self


def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    keyword_table_handler = JiebaKeywordTableHandler()
    keyword_table = {}
    for text in texts:
        keywords = keyword_table_handler.extract_keywords(text.page_content, self._config.max_keywords_per_chunk)
        self._update_segment_keywords(self.dataset.id, text.metadata['doc_id'], list(keywords))
        keyword_table = self._add_text_to_keyword_table(keyword_table, text.metadata['doc_id'], list(keywords))

    dataset_keyword_table = DatasetKeywordTable(
        dataset_id=self.dataset.id,
        keyword_table=json.dumps({
            '__type__': 'keyword_table',
            '__data__': {
                "index_id": self.dataset.id,
                "summary": None,
                "table": {}
            }
        }, cls=SetEncoder)
    )
    db.session.add(dataset_keyword_table)
    db.session.commit()

    self._save_dataset_keyword_table(keyword_table)

    return self

def add_texts(self, texts: list[Document], **kwargs):
    keyword_table_handler = JiebaKeywordTableHandler()


db.session.delete(dataset_keyword_table)
db.session.commit()


def delete_by_group_id(self, group_id: str) -> None:
    dataset_keyword_table = self.dataset.dataset_keyword_table
    if dataset_keyword_table:
        db.session.delete(dataset_keyword_table)
        db.session.commit()

def _save_dataset_keyword_table(self, keyword_table):
    keyword_table_dict = {
        '__type__': 'keyword_table',

api/core/index/vector_index/base.py (+57, -1)



from core.index.base import BaseIndex
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DatasetCollectionBinding
from models.dataset import Document as DatasetDocument




for node_id in ids:
    vector_store.del_text(node_id)


def delete_by_group_id(self, group_id: str) -> None:
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

    vector_store.delete()

def delete(self) -> None:
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
raise e


logging.info(f"Dataset {dataset.id} recreate successfully.") logging.info(f"Dataset {dataset.id} recreate successfully.")

def restore_dataset_in_one(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
    logging.info(f"restore dataset in_one,_dataset {dataset.id}")

    dataset_documents = db.session.query(DatasetDocument).filter(
        DatasetDocument.dataset_id == dataset.id,
        DatasetDocument.indexing_status == 'completed',
        DatasetDocument.enabled == True,
        DatasetDocument.archived == False,
    ).all()

    documents = []
    for dataset_document in dataset_documents:
        segments = db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == dataset_document.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).all()

        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )

            documents.append(document)

    if documents:
        try:
            self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
        except Exception as e:
            raise e

    logging.info(f"Dataset {dataset.id} recreate successfully.")

def delete_original_collection(self, dataset: Dataset, dataset_collection_binding: DatasetCollectionBinding):
    logging.info(f"delete original collection: {dataset.id}")

    self.delete()

    dataset.collection_binding_id = dataset_collection_binding.id
    db.session.add(dataset)
    db.session.commit()

    logging.info(f"Dataset {dataset.id} recreate successfully.")

api/core/index/vector_index/milvus_vector_index.py (+13, -0)



return self


def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    uuids = self._get_uuids(texts)
    self._vector_store = WeaviateVectorStore.from_documents(
        texts,
        self._embeddings,
        client=self._client,
        index_name=collection_name,
        uuids=uuids,
        by_text=False
    )

    return self

def _get_vector_store(self) -> VectorStore:
    """Only for created index."""
    if self._vector_store:

api/core/index/vector_index/qdrant.py (+38, -2)

from langchain.embeddings.base import Embeddings
from langchain.vectorstores import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
from qdrant_client.http.models import PayloadSchemaType

if TYPE_CHECKING:
    from qdrant_client import grpc  # noqa

CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR_NAME = None


def __init__(
    embeddings: Optional[Embeddings] = None,
    content_payload_key: str = CONTENT_KEY,
    metadata_payload_key: str = METADATA_KEY,
    group_payload_key: str = GROUP_KEY,
    group_id: str = None,
    distance_strategy: str = "COSINE",
    vector_name: Optional[str] = VECTOR_NAME,
    embedding_function: Optional[Callable] = None,  # deprecated
    is_new_collection: bool = False
):
    """Initialize with necessary components."""
    try:
    self.collection_name = collection_name
    self.content_payload_key = content_payload_key or self.CONTENT_KEY
    self.metadata_payload_key = metadata_payload_key or self.METADATA_KEY
    self.group_payload_key = group_payload_key or self.GROUP_KEY
    self.vector_name = vector_name or self.VECTOR_NAME
    self.group_id = group_id
    self.is_new_collection = is_new_collection


if embedding_function is not None:
    warnings.warn(
batch_size:
    How many vectors upload per-request.
    Default: 64
group_id:
    collection group

Returns:
    List of ids from adding the texts into the vectorstore.
collection_name=self.collection_name, points=points, **kwargs
)
added_ids.extend(batch_ids)

# if is new collection, create payload index on group_id
if self.is_new_collection:
    self.client.create_payload_index(self.collection_name, self.group_payload_key,
                                     field_schema=PayloadSchemaType.KEYWORD,
                                     field_type=PayloadSchemaType.KEYWORD)
return added_ids


@sync_call_fallback
distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME,
batch_size: int = 64,
shard_number: Optional[int] = None,
metadata_payload_key:
    A payload key used to store the metadata of the document.
    Default: "metadata"
group_payload_key:
    A payload key used to store the group id of the document.
    Default: "group_id"
group_id:
    collection group id
vector_name:
    Name of the vector to be used internally in Qdrant.
    Default: None
distance_func,
content_payload_key,
metadata_payload_key,
group_payload_key,
group_id,
vector_name,
shard_number,
replication_factor,
distance_func: str = "Cosine", distance_func: str = "Cosine",
content_payload_key: str = CONTENT_KEY, content_payload_key: str = CONTENT_KEY,
metadata_payload_key: str = METADATA_KEY, metadata_payload_key: str = METADATA_KEY,
group_payload_key: str = GROUP_KEY,
group_id: str = None,
vector_name: Optional[str] = VECTOR_NAME, vector_name: Optional[str] = VECTOR_NAME,
shard_number: Optional[int] = None, shard_number: Optional[int] = None,
replication_factor: Optional[int] = None, replication_factor: Optional[int] = None,
vector_size = len(partial_embeddings[0])
collection_name = collection_name or uuid.uuid4().hex
distance_func = distance_func.upper()
is_new_collection = False
client = qdrant_client.QdrantClient(
    location=location,
    url=url,
    init_from=init_from,
    timeout=timeout,  # type: ignore[arg-type]
)
is_new_collection = True
qdrant = cls(
    client=client,
    collection_name=collection_name,
    metadata_payload_key=metadata_payload_key,
    distance_strategy=distance_func,
    vector_name=vector_name,
    group_id=group_id,
    group_payload_key=group_payload_key,
    is_new_collection=is_new_collection
)
return qdrant


metadatas: Optional[List[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> List[dict]:
    payloads = []
    for i, text in enumerate(texts):
        {
            content_payload_key: text,
            metadata_payload_key: metadata,
            group_payload_key: group_id
        }
    )


else:
    out.append(
        rest.FieldCondition(
            key=key,
            match=rest.MatchValue(value=value),
        )
    )
metadatas: Optional[List[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[Tuple[List[str], List[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest


batch_metadatas,
self.content_payload_key,
self.metadata_payload_key,
self.group_id,
self.group_payload_key
),
)
]

api/core/index/vector_index/qdrant_vector_index.py (+58, -20)

from langchain.schema import Document, BaseRetriever
from langchain.vectorstores import VectorStore
from pydantic import BaseModel
from qdrant_client.http.models import HnswConfigDiff

from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.qdrant_vector_store import QdrantVectorStore
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding




class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str]
    root_path: Optional[str]
    def to_qdrant_params(self):
        if self.endpoint and self.endpoint.startswith('path:'):
            path = self.endpoint.replace('path:', '')
return 'qdrant'


def get_index_name(self, dataset: Dataset) -> str:
    if dataset.collection_binding_id:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
            one_or_none()
        if dataset_collection_binding:
            return dataset_collection_binding.collection_name
        else:
            raise ValueError('Dataset Collection Bindings is not exist!')
    else:
        if self.dataset.index_struct_dict:
            class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
            return class_prefix

        dataset_id = dataset.id
        return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'


def to_index_struct(self) -> dict:
    return {
collection_name=self.get_index_name(self.dataset),
ids=uuids,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id',
hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                           max_indexing_threads=0, on_disk=False),
**self._client_config.to_qdrant_params()
)

return self

def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    uuids = self._get_uuids(texts)
    self._vector_store = QdrantVectorStore.from_documents(
        texts,
        self._embeddings,
        collection_name=collection_name,
        ids=uuids,
        content_payload_key='page_content',
        group_id=self.dataset.id,
        group_payload_key='group_id',
        hnsw_config=HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
                                   max_indexing_threads=0, on_disk=False),
        **self._client_config.to_qdrant_params()
    )


if self._vector_store:
    return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
    attributes = ['doc_id']
client = qdrant_client.QdrantClient(
    **self._client_config.to_qdrant_params()
)
client=client,
collection_name=self.get_index_name(self.dataset),
embeddings=self._embeddings,
content_payload_key='page_content',
group_id=self.dataset.id,
group_payload_key='group_id'
)


def _get_vector_store_class(self) -> type:
    return QdrantVectorStore


def delete_by_document_id(self, document_id: str):
    if self._is_origin():
        self.recreate_dataset(self.dataset)
        return

    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
))


def delete_by_ids(self, ids: list[str]) -> None:
    if self._is_origin():
        self.recreate_dataset(self.dataset)
        return

    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)
],
))


def delete_by_group_id(self, group_id: str) -> None:
    vector_store = self._get_vector_store()
    vector_store = cast(self._get_vector_store_class(), vector_store)

    from qdrant_client.http import models
    vector_store.del_texts(models.Filter(
        must=[
            models.FieldCondition(
                key="group_id",
                match=models.MatchValue(value=group_id),
            ),
        ],
    ))


def _is_origin(self):
    if self.dataset.index_struct_dict:
        class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']

api/core/index/vector_index/weaviate_vector_index.py (+14, -0)



return self


def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
    uuids = self._get_uuids(texts)
    self._vector_store = WeaviateVectorStore.from_documents(
        texts,
        self._embeddings,
        client=self._client,
        index_name=self.get_index_name(self.dataset),
        uuids=uuids,
        by_text=False
    )

    return self


def _get_vector_store(self) -> VectorStore:
    """Only for created index."""
    if self._vector_store:

api/core/tool/dataset_retriever_tool.py (+4, -2)

return_resource: str
retriever_from: str

@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
    description = dataset.description
query,
search_type='similarity_score_threshold',
search_kwargs={
    'k': self.k,
    'filter': {
        'group_id': [dataset.id]
    }
}
)
else:

api/core/vector_store/qdrant_vector_store.py (+5, -0)



self.client.delete_collection(collection_name=self.collection_name)


def delete_group(self):
    self._reload_if_needed()

    self.client.delete_collection(collection_name=self.collection_name)

@classmethod
def _document_from_scored_point(
    cls,

api/migrations/versions/6e2cfb077b04_add_dataset_collection_binding.py (+47, -0)

"""add_dataset_collection_binding

Revision ID: 6e2cfb077b04
Revises: 77e83833755c
Create Date: 2023-09-13 22:16:48.027810

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '6e2cfb077b04'
down_revision = '77e83833755c'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('dataset_collection_bindings',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('provider_name', sa.String(length=40), nullable=False),
        sa.Column('model_name', sa.String(length=40), nullable=False),
        sa.Column('collection_name', sa.String(length=64), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey')
    )
    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
        batch_op.create_index('provider_model_name_idx', ['provider_name', 'model_name'], unique=False)

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('collection_binding_id', postgresql.UUID(), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('collection_binding_id')

    with op.batch_alter_table('dataset_collection_bindings', schema=None) as batch_op:
        batch_op.drop_index('provider_model_name_idx')

    op.drop_table('dataset_collection_bindings')
    # ### end Alembic commands ###

api/models/dataset.py (+18, -0)

server_default=db.text('CURRENT_TIMESTAMP(0)'))
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)



@property
def dataset_keyword_table(self):

def get_embedding(self) -> list[float]:
    return pickle.loads(self.embedding)


class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


api/services/dataset_service.py (+41, -3)

from extensions.ext_database import db
from libs import helper
from models.account import Account
from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin, DocumentSegment, \
    DatasetCollectionBinding
from models.model import UploadFile
from models.source import DataSourceBinding
from services.errors.account import NoPermissionError
action = 'remove'
filtered_data['embedding_model'] = None
filtered_data['embedding_model_provider'] = None
filtered_data['collection_binding_id'] = None
elif data['indexing_technique'] == 'high_quality':
    action = 'add'
    # get embedding model setting
)
filtered_data['embedding_model'] = embedding_model.name
filtered_data['embedding_model_provider'] = embedding_model.model_provider.provider_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    embedding_model.model_provider.provider_name,
    embedding_model.name
)
filtered_data['collection_binding_id'] = dataset_collection_binding.id
except LLMBadRequestError:
    raise ValueError(
        f"No Embedding Model available. Please configure a valid provider "
    )
dataset.embedding_model = embedding_model.name
dataset.embedding_model_provider = embedding_model.model_provider.provider_name

dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    embedding_model.model_provider.provider_name,
    embedding_model.name
)
dataset.collection_binding_id = dataset_collection_binding.id


documents = []
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
if total_count > tenant_document_count:
    raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
embedding_model = None
dataset_collection_binding_id = None
if document_data['indexing_technique'] == 'high_quality':
    embedding_model = ModelFactory.get_embedding_model(
        tenant_id=tenant_id
    )
    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
        embedding_model.model_provider.provider_name,
        embedding_model.name
    )
    dataset_collection_binding_id = dataset_collection_binding.id
# save dataset
dataset = Dataset(
    tenant_id=tenant_id,
    indexing_technique=document_data["indexing_technique"],
    created_by=account.id,
    embedding_model=embedding_model.name if embedding_model else None,
    embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None,
    collection_binding_id=dataset_collection_binding_id
)


db.session.add(dataset)
delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
db.session.delete(segment)
db.session.commit()


class DatasetCollectionBindingService:
    @classmethod
    def get_dataset_collection_binding(cls, provider_name: str, model_name: str) -> DatasetCollectionBinding:
        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
            filter(DatasetCollectionBinding.provider_name == provider_name,
                   DatasetCollectionBinding.model_name == model_name). \
            order_by(DatasetCollectionBinding.created_at). \
            first()

        if not dataset_collection_binding:
            dataset_collection_binding = DatasetCollectionBinding(
                provider_name=provider_name,
                model_name=model_name,
                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
            )
            db.session.add(dataset_collection_binding)
            db.session.flush()
        return dataset_collection_binding
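For reference, a minimal sketch of how callers are expected to use this service when wiring a dataset to the shared collection (it mirrors the dataset creation and update paths above; `dataset` and `embedding_model` are assumed to already exist):

binding = DatasetCollectionBindingService.get_dataset_collection_binding(
    embedding_model.model_provider.provider_name,
    embedding_model.name
)
dataset.collection_binding_id = binding.id
db.session.add(dataset)
db.session.commit()  # the service only flushes a newly created binding; the caller commits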

api/services/hit_testing_service.py (+4, -1)

query,
search_type='similarity_score_threshold',
search_kwargs={
    'k': 10,
    'filter': {
        'group_id': [dataset.id]
    }
}
)
end = time.perf_counter()
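As with the dataset retriever tool change above, the new 'filter' entry scopes hit testing to the dataset's own points inside a shared collection. A rough sketch of the Qdrant-side condition this is expected to translate to (payload key taken from GROUP_KEY in qdrant.py; the exact translation is performed by the vector store wrapper and may differ in detail):

from qdrant_client.http import models as rest

# hypothetical standalone construction of the group filter for a given dataset id
group_filter = rest.Filter(
    must=[
        rest.FieldCondition(
            key="group_id",                          # the group payload key written with every point
            match=rest.MatchAny(any=[dataset.id]),   # dataset.id is assumed to be in scope
        )
    ]
)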
