
refactor(rag): switch to dify_config. (#6410)

Co-authored-by: -LAN- <laipz8200@outlook.com>
tags/0.6.15
Poorandy, 1 year ago
parent commit: c8f5dfcf17
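The change is mechanical across every file below: reads through Flask's untyped `current_app.config` mapping are replaced by attribute access on the shared `dify_config` object. A minimal sketch of the pattern, assuming (as in Dify's `configs` package) a pydantic-settings style object; the `DifyConfig` class shown here is an illustrative stand-in, not the project's actual definition:

    # Illustrative sketch only; field names mirror settings used in this diff.
    from typing import Optional

    from pydantic_settings import BaseSettings


    class DifyConfig(BaseSettings):
        # Typed fields, loaded from the environment when the object is created.
        VECTOR_STORE: Optional[str] = None
        MILVUS_HOST: Optional[str] = None
        MILVUS_PORT: int = 19530


    dify_config = DifyConfig()

    # Before: untyped mapping lookup, only valid inside a Flask app context.
    #   host = current_app.config.get('MILVUS_HOST')
    # After: plain attribute access on a module-level, typed settings object.
    host = dify_config.MILVUS_HOST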

api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py (+19, -18)

      "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
      "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
  )
- from flask import current_app
+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

      "region_id": self.region_id,
      "read_timeout": self.read_timeout,
  }

  class AnalyticdbVector(BaseVector):
      _instance = None
      _init = False

      if cls._instance is None:
          cls._instance = super().__new__(cls)
      return cls._instance

      def __init__(self, collection_name: str, config: AnalyticdbConfig):
          # collection_name must be updated every time
          self._collection_name = collection_name.lower()

      raise ValueError(
          f"failed to create namespace {self.config.namespace}: {e}"
      )

      def _create_collection_if_not_exists(self, embedding_dimension: int):
          from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
          from Tea.exceptions import TeaException

      def get_type(self) -> str:
          return VectorType.ANALYTICDB

      def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
          dimension = len(embeddings[0])
          self._create_collection_if_not_exists(dimension)

      )
      response = self._client.query_collection_data(request)
      return len(response.body.matches.match) > 0

      def delete_by_ids(self, ids: list[str]) -> None:
          from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
          ids_str = ",".join(f"'{id}'" for id in ids)

      )
      documents.append(doc)
      return documents

      def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
          from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
          score_threshold = (
      )
      documents.append(doc)
      return documents

      def delete(self) -> None:
          from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
          request = gpdb_20160503_models.DeleteCollectionRequest(

      dataset.index_struct = json.dumps(
          self.gen_index_struct_dict(VectorType.ANALYTICDB, collection_name)
      )
-     config = current_app.config

      # TODO handle optional params
      return AnalyticdbVector(
          collection_name,
          AnalyticdbConfig(
-             access_key_id=config.get("ANALYTICDB_KEY_ID"),
-             access_key_secret=config.get("ANALYTICDB_KEY_SECRET"),
-             region_id=config.get("ANALYTICDB_REGION_ID"),
-             instance_id=config.get("ANALYTICDB_INSTANCE_ID"),
-             account=config.get("ANALYTICDB_ACCOUNT"),
-             account_password=config.get("ANALYTICDB_PASSWORD"),
-             namespace=config.get("ANALYTICDB_NAMESPACE"),
-             namespace_password=config.get("ANALYTICDB_NAMESPACE_PASSWORD"),
+             access_key_id=dify_config.ANALYTICDB_KEY_ID,
+             access_key_secret=dify_config.ANALYTICDB_KEY_SECRET,
+             region_id=dify_config.ANALYTICDB_REGION_ID,
+             instance_id=dify_config.ANALYTICDB_INSTANCE_ID,
+             account=dify_config.ANALYTICDB_ACCOUNT,
+             account_password=dify_config.ANALYTICDB_PASSWORD,
+             namespace=dify_config.ANALYTICDB_NAMESPACE,
+             namespace_password=dify_config.ANALYTICDB_NAMESPACE_PASSWORD,
          ),
      )
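The factory above still carries a `# TODO handle optional params` comment: every `ANALYTICDB_*` setting is passed straight through, even when unset. One hypothetical way to resolve that TODO, failing fast when a required setting is missing rather than passing None into AnalyticdbConfig; the `_require` helper is illustrative and not part of this commit:

    from configs import dify_config


    def _require(name: str) -> str:
        # Look the setting up on the typed config object and insist it is set.
        value = getattr(dify_config, name, None)
        if value is None:
            raise ValueError(f"{name} must be set when using the AnalyticDB vector store")
        return value

    access_key_id = _require("ANALYTICDB_KEY_ID")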

api/core/rag/datasource/vdb/chroma/chroma_vector.py (+7, -8)

  import chromadb
  from chromadb import QueryResult, Settings
- from flask import current_app
  from pydantic import BaseModel

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

  }
  dataset.index_struct = json.dumps(index_struct_dict)

- config = current_app.config
  return ChromaVector(
      collection_name=collection_name,
      config=ChromaConfig(
-         host=config.get('CHROMA_HOST'),
-         port=int(config.get('CHROMA_PORT')),
-         tenant=config.get('CHROMA_TENANT', chromadb.DEFAULT_TENANT),
-         database=config.get('CHROMA_DATABASE', chromadb.DEFAULT_DATABASE),
-         auth_provider=config.get('CHROMA_AUTH_PROVIDER'),
-         auth_credentials=config.get('CHROMA_AUTH_CREDENTIALS'),
+         host=dify_config.CHROMA_HOST,
+         port=dify_config.CHROMA_PORT,
+         tenant=dify_config.CHROMA_TENANT or chromadb.DEFAULT_TENANT,
+         database=dify_config.CHROMA_DATABASE or chromadb.DEFAULT_DATABASE,
+         auth_provider=dify_config.CHROMA_AUTH_PROVIDER,
+         auth_credentials=dify_config.CHROMA_AUTH_CREDENTIALS,
      ),
  )

api/core/rag/datasource/vdb/milvus/milvus_vector.py (+7, -8)

  from typing import Any, Optional
  from uuid import uuid4

- from flask import current_app
  from pydantic import BaseModel, model_validator
  from pymilvus import MilvusClient, MilvusException, connections

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.field import Field
  from core.rag.datasource.vdb.vector_base import BaseVector

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.MILVUS, collection_name))

- config = current_app.config
  return MilvusVector(
      collection_name=collection_name,
      config=MilvusConfig(
-         host=config.get('MILVUS_HOST'),
-         port=config.get('MILVUS_PORT'),
-         user=config.get('MILVUS_USER'),
-         password=config.get('MILVUS_PASSWORD'),
-         secure=config.get('MILVUS_SECURE'),
-         database=config.get('MILVUS_DATABASE'),
+         host=dify_config.MILVUS_HOST,
+         port=dify_config.MILVUS_PORT,
+         user=dify_config.MILVUS_USER,
+         password=dify_config.MILVUS_PASSWORD,
+         secure=dify_config.MILVUS_SECURE,
+         database=dify_config.MILVUS_DATABASE,
      )
  )

api/core/rag/datasource/vdb/myscale/myscale_vector.py (+8, -8)

  from typing import Any

  from clickhouse_connect import get_client
- from flask import current_app
  from pydantic import BaseModel

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.MYSCALE, collection_name))

- config = current_app.config
  return MyScaleVector(
      collection_name=collection_name,
      config=MyScaleConfig(
-         host=config.get("MYSCALE_HOST", "localhost"),
-         port=int(config.get("MYSCALE_PORT", 8123)),
-         user=config.get("MYSCALE_USER", "default"),
-         password=config.get("MYSCALE_PASSWORD", ""),
-         database=config.get("MYSCALE_DATABASE", "default"),
-         fts_params=config.get("MYSCALE_FTS_PARAMS", ""),
+         # TODO: I think setting those values as the default config would be a better option.
+         host=dify_config.MYSCALE_HOST or "localhost",
+         port=dify_config.MYSCALE_PORT or 8123,
+         user=dify_config.MYSCALE_USER or "default",
+         password=dify_config.MYSCALE_PASSWORD or "",
+         database=dify_config.MYSCALE_DATABASE or "default",
+         fts_params=dify_config.MYSCALE_FTS_PARAMS or "",
      ),
  )
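The TODO added in this hunk points at the obvious follow-up: declare the defaults on the settings class itself so call sites can drop the `or "..."` fallbacks. A sketch of that direction; the field names match the diff, but the class is hypothetical:

    from pydantic_settings import BaseSettings


    class MyScaleSettings(BaseSettings):  # illustrative, not Dify's actual class
        MYSCALE_HOST: str = "localhost"
        MYSCALE_PORT: int = 8123
        MYSCALE_USER: str = "default"
        MYSCALE_PASSWORD: str = ""
        MYSCALE_DATABASE: str = "default"
        MYSCALE_FTS_PARAMS: str = ""

    # With defaults centralized, the factory call could become simply:
    #   host=dify_config.MYSCALE_HOST,
    #   port=dify_config.MYSCALE_PORT,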

api/core/rag/datasource/vdb/opensearch/opensearch_vector.py (+6, -7)

  from typing import Any, Optional
  from uuid import uuid4

- from flask import current_app
  from opensearchpy import OpenSearch, helpers
  from opensearchpy.helpers import BulkIndexError
  from pydantic import BaseModel, model_validator

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.field import Field
  from core.rag.datasource.vdb.vector_base import BaseVector

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.OPENSEARCH, collection_name))

- config = current_app.config

  open_search_config = OpenSearchConfig(
-     host=config.get('OPENSEARCH_HOST'),
-     port=config.get('OPENSEARCH_PORT'),
-     user=config.get('OPENSEARCH_USER'),
-     password=config.get('OPENSEARCH_PASSWORD'),
-     secure=config.get('OPENSEARCH_SECURE'),
+     host=dify_config.OPENSEARCH_HOST,
+     port=dify_config.OPENSEARCH_PORT,
+     user=dify_config.OPENSEARCH_USER,
+     password=dify_config.OPENSEARCH_PASSWORD,
+     secure=dify_config.OPENSEARCH_SECURE,
  )

  return OpenSearchVector(

api/core/rag/datasource/vdb/oracle/oraclevector.py (+8, -9)

  import numpy
  import oracledb
- from flask import current_app
  from pydantic import BaseModel, model_validator

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

  SQL_CREATE_TABLE = """
  CREATE TABLE IF NOT EXISTS {table_name} (
      id varchar2(100)
      ,text CLOB NOT NULL
      ,meta JSON
      ,embedding vector NOT NULL
  )
  """

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.ORACLE, collection_name))

- config = current_app.config
  return OracleVector(
      collection_name=collection_name,
      config=OracleVectorConfig(
-         host=config.get("ORACLE_HOST"),
-         port=config.get("ORACLE_PORT"),
-         user=config.get("ORACLE_USER"),
-         password=config.get("ORACLE_PASSWORD"),
-         database=config.get("ORACLE_DATABASE"),
+         host=dify_config.ORACLE_HOST,
+         port=dify_config.ORACLE_PORT,
+         user=dify_config.ORACLE_USER,
+         password=dify_config.ORACLE_PASSWORD,
+         database=dify_config.ORACLE_DATABASE,
      ),
  )

api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py (+9, -9)

  from typing import Any
  from uuid import UUID, uuid4

- from flask import current_app
  from numpy import ndarray
  from pgvecto_rs.sqlalchemy import Vector
  from pydantic import BaseModel, model_validator
  from sqlalchemy.dialects import postgresql
  from sqlalchemy.orm import Mapped, Session, mapped_column

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM
  from core.rag.datasource.vdb.vector_base import BaseVector

      text TEXT NOT NULL,
      meta JSONB NOT NULL,
      vector vector({dimension}) NOT NULL
  ) using heap;
  """)
  session.execute(create_statement)
  index_statement = sql_text(f"""

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name))
  dim = len(embeddings.embed_query("pgvecto_rs"))
- config = current_app.config
  return PGVectoRS(
      collection_name=collection_name,
      config=PgvectoRSConfig(
-         host=config.get('PGVECTO_RS_HOST'),
-         port=config.get('PGVECTO_RS_PORT'),
-         user=config.get('PGVECTO_RS_USER'),
-         password=config.get('PGVECTO_RS_PASSWORD'),
-         database=config.get('PGVECTO_RS_DATABASE'),
+         host=dify_config.PGVECTO_RS_HOST,
+         port=dify_config.PGVECTO_RS_PORT,
+         user=dify_config.PGVECTO_RS_USER,
+         password=dify_config.PGVECTO_RS_PASSWORD,
+         database=dify_config.PGVECTO_RS_DATABASE,
      ),
      dim=dim
  )

api/core/rag/datasource/vdb/pgvector/pgvector.py (+8, -9)

  import psycopg2.extras
  import psycopg2.pool
- from flask import current_app
  from pydantic import BaseModel, model_validator

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

      text TEXT NOT NULL,
      meta JSONB NOT NULL,
      embedding vector({dimension}) NOT NULL
  ) using heap;
  """

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.PGVECTOR, collection_name))

- config = current_app.config
  return PGVector(
      collection_name=collection_name,
      config=PGVectorConfig(
-         host=config.get("PGVECTOR_HOST"),
-         port=config.get("PGVECTOR_PORT"),
-         user=config.get("PGVECTOR_USER"),
-         password=config.get("PGVECTOR_PASSWORD"),
-         database=config.get("PGVECTOR_DATABASE"),
+         host=dify_config.PGVECTOR_HOST,
+         port=dify_config.PGVECTOR_PORT,
+         user=dify_config.PGVECTOR_USER,
+         password=dify_config.PGVECTOR_PASSWORD,
+         database=dify_config.PGVECTOR_DATABASE,
      ),
  )

api/core/rag/datasource/vdb/qdrant/qdrant_vector.py (+6, -5)

  )
  from qdrant_client.local.qdrant_local import QdrantLocal

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.field import Field
  from core.rag.datasource.vdb.vector_base import BaseVector

  collection_name=collection_name,
  group_id=dataset.id,
  config=QdrantConfig(
-     endpoint=config.get('QDRANT_URL'),
-     api_key=config.get('QDRANT_API_KEY'),
+     endpoint=dify_config.QDRANT_URL,
+     api_key=dify_config.QDRANT_API_KEY,
      root_path=config.root_path,
-     timeout=config.get('QDRANT_CLIENT_TIMEOUT'),
-     grpc_port=config.get('QDRANT_GRPC_PORT'),
-     prefer_grpc=config.get('QDRANT_GRPC_ENABLED')
+     timeout=dify_config.QDRANT_CLIENT_TIMEOUT,
+     grpc_port=dify_config.QDRANT_GRPC_PORT,
+     prefer_grpc=dify_config.QDRANT_GRPC_ENABLED
  )
  )

api/core/rag/datasource/vdb/relyt/relyt_vector.py (+7, -8)

  import uuid
  from typing import Any, Optional

- from flask import current_app
  from pydantic import BaseModel, model_validator
  from sqlalchemy import Column, Sequence, String, Table, create_engine, insert
  from sqlalchemy import text as sql_text

  except ImportError:
      from sqlalchemy.ext.declarative import declarative_base

+ from configs import dify_config
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.models.document import Document
  from extensions.ext_redis import redis_client

      document TEXT NOT NULL,
      metadata JSON NOT NULL,
      embedding vector({dimension}) NOT NULL
  ) using heap;
  """)
  session.execute(create_statement)
  index_statement = sql_text(f"""

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.RELYT, collection_name))

- config = current_app.config
  return RelytVector(
      collection_name=collection_name,
      config=RelytConfig(
-         host=config.get('RELYT_HOST'),
-         port=config.get('RELYT_PORT'),
-         user=config.get('RELYT_USER'),
-         password=config.get('RELYT_PASSWORD'),
-         database=config.get('RELYT_DATABASE'),
+         host=dify_config.RELYT_HOST,
+         port=dify_config.RELYT_PORT,
+         user=dify_config.RELYT_USER,
+         password=dify_config.RELYT_PASSWORD,
+         database=dify_config.RELYT_DATABASE,
      ),
      group_id=dataset.id
  )

api/core/rag/datasource/vdb/tencent/tencent_vector.py (+9, -10)

  import json
  from typing import Any, Optional

- from flask import current_app
  from pydantic import BaseModel
  from tcvectordb import VectorDBClient
  from tcvectordb.model import document, enum
  from tcvectordb.model import index as vdb_index
  from tcvectordb.model.document import Filter

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.TENCENT, collection_name))

- config = current_app.config
  return TencentVector(
      collection_name=collection_name,
      config=TencentConfig(
-         url=config.get('TENCENT_VECTOR_DB_URL'),
-         api_key=config.get('TENCENT_VECTOR_DB_API_KEY'),
-         timeout=config.get('TENCENT_VECTOR_DB_TIMEOUT'),
-         username=config.get('TENCENT_VECTOR_DB_USERNAME'),
-         database=config.get('TENCENT_VECTOR_DB_DATABASE'),
-         shard=config.get('TENCENT_VECTOR_DB_SHARD'),
-         replicas=config.get('TENCENT_VECTOR_DB_REPLICAS'),
+         url=dify_config.TENCENT_VECTOR_DB_URL,
+         api_key=dify_config.TENCENT_VECTOR_DB_API_KEY,
+         timeout=dify_config.TENCENT_VECTOR_DB_TIMEOUT,
+         username=dify_config.TENCENT_VECTOR_DB_USERNAME,
+         database=dify_config.TENCENT_VECTOR_DB_DATABASE,
+         shard=dify_config.TENCENT_VECTOR_DB_SHARD,
+         replicas=dify_config.TENCENT_VECTOR_DB_REPLICAS,
      )
  )

api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py (+10, -11)

  from typing import Any

  import sqlalchemy
- from flask import current_app
  from pydantic import BaseModel, model_validator
  from sqlalchemy import JSON, TEXT, Column, DateTime, String, Table, create_engine, insert
  from sqlalchemy import text as sql_text
  from sqlalchemy.orm import Session, declarative_base

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.vector_base import BaseVector
  from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory

  with Session(self._engine) as session:
      select_statement = sql_text(
          f"""SELECT meta, text, distance FROM (
              SELECT meta, text, {tidb_func}(vector, "{query_vector_str}") as distance
              FROM {self._collection_name}
              ORDER BY distance
              LIMIT {top_k}
          ) t WHERE distance < {distance};"""

  dataset.index_struct = json.dumps(
      self.gen_index_struct_dict(VectorType.TIDB_VECTOR, collection_name))

- config = current_app.config
  return TiDBVector(
      collection_name=collection_name,
      config=TiDBVectorConfig(
-         host=config.get('TIDB_VECTOR_HOST'),
-         port=config.get('TIDB_VECTOR_PORT'),
-         user=config.get('TIDB_VECTOR_USER'),
-         password=config.get('TIDB_VECTOR_PASSWORD'),
-         database=config.get('TIDB_VECTOR_DATABASE'),
-         program_name=config.get('APPLICATION_NAME'),
+         host=dify_config.TIDB_VECTOR_HOST,
+         port=dify_config.TIDB_VECTOR_PORT,
+         user=dify_config.TIDB_VECTOR_USER,
+         password=dify_config.TIDB_VECTOR_PASSWORD,
+         database=dify_config.TIDB_VECTOR_DATABASE,
+         program_name=dify_config.APPLICATION_NAME,
      ),
  )

api/core/rag/datasource/vdb/vector_factory.py (+2, -4)

  from abc import ABC, abstractmethod
  from typing import Any

- from flask import current_app
+ from configs import dify_config
  from core.embedding.cached_embedding import CacheEmbedding
  from core.model_manager import ModelManager
  from core.model_runtime.entities.model_entities import ModelType

  self._vector_processor = self._init_vector()

  def _init_vector(self) -> BaseVector:
-     config = current_app.config
-     vector_type = config.get('VECTOR_STORE')
+     vector_type = dify_config.VECTOR_STORE
      if self._dataset.index_struct_dict:
          vector_type = self._dataset.index_struct_dict['type']
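This hunk is the dispatch point for every store above: the global default now comes from typed settings, while a dataset that was already indexed keeps the store it was built with. A condensed sketch of that selection logic under those assumptions (the real factory goes on to dispatch to the per-store factories shown earlier in this commit):

    from configs import dify_config


    def resolve_vector_type(dataset) -> str:
        # Global default from the typed settings object, not current_app.config.
        vector_type = dify_config.VECTOR_STORE
        # An already-indexed dataset pins the vector store recorded at index time.
        if dataset.index_struct_dict:
            vector_type = dataset.index_struct_dict['type']
        if not vector_type:
            raise ValueError("Vector store must be specified.")
        return vector_type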



api/core/rag/datasource/vdb/weaviate/weaviate_vector.py (+4, -4)

  import requests
  import weaviate
- from flask import current_app
  from pydantic import BaseModel, model_validator

+ from configs import dify_config
  from core.rag.datasource.entity.embedding import Embeddings
  from core.rag.datasource.vdb.field import Field
  from core.rag.datasource.vdb.vector_base import BaseVector

  return WeaviateVector(
      collection_name=collection_name,
      config=WeaviateConfig(
-         endpoint=current_app.config.get('WEAVIATE_ENDPOINT'),
-         api_key=current_app.config.get('WEAVIATE_API_KEY'),
-         batch_size=int(current_app.config.get('WEAVIATE_BATCH_SIZE'))
+         endpoint=dify_config.WEAVIATE_ENDPOINT,
+         api_key=dify_config.WEAVIATE_API_KEY,
+         batch_size=dify_config.WEAVIATE_BATCH_SIZE
      ),
      attributes=attributes
  )
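Note that the `int(...)` cast around `WEAVIATE_BATCH_SIZE` disappears here, as do the casts on `CHROMA_PORT` and `MYSCALE_PORT` above. That works because a pydantic-settings field coerces the environment string to its declared type when the settings load. A tiny illustrative sketch (the class and default are hypothetical stand-ins):

    from pydantic_settings import BaseSettings


    class WeaviateSettings(BaseSettings):  # illustrative, not Dify's actual class
        # The environment value "100" is coerced to int(100) at load time,
        # so call sites can read the field without casting.
        WEAVIATE_BATCH_SIZE: int = 100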

api/core/rag/extractor/extract_processor.py (+4, -4)

  from urllib.parse import unquote

  import requests
- from flask import current_app

+ from configs import dify_config
  from core.rag.extractor.csv_extractor import CSVExtractor
  from core.rag.extractor.entity.datasource_type import DatasourceType
  from core.rag.extractor.entity.extract_setting import ExtractSetting

  storage.download(upload_file.key, file_path)
  input_file = Path(file_path)
  file_extension = input_file.suffix.lower()
- etl_type = current_app.config['ETL_TYPE']
- unstructured_api_url = current_app.config['UNSTRUCTURED_API_URL']
- unstructured_api_key = current_app.config['UNSTRUCTURED_API_KEY']
+ etl_type = dify_config.ETL_TYPE
+ unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
+ unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
  if etl_type == 'Unstructured':
      if file_extension == '.xlsx' or file_extension == '.xls':
          extractor = ExcelExtractor(file_path)

api/core/rag/extractor/notion_extractor.py (+2, -2)

  from typing import Any, Optional

  import requests
- from flask import current_app

+ from configs import dify_config
  from core.rag.extractor.extractor_base import BaseExtractor
  from core.rag.models.document import Document
  from extensions.ext_database import db

  self._notion_access_token = self._get_access_token(tenant_id,
                                                     self._notion_workspace_id)
  if not self._notion_access_token:
-     integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
+     integration_token = dify_config.NOTION_INTEGRATION_TOKEN
      if integration_token is None:
          raise ValueError(
              "Must specify `integration_token` or set environment "

api/core/rag/extractor/word_extractor.py (+3, -4)

  import requests
  from docx import Document as DocxDocument
- from flask import current_app

+ from configs import dify_config
  from core.rag.extractor.extractor_base import BaseExtractor
  from core.rag.models.document import Document
  from extensions.ext_database import db

  storage.save(file_key, rel.target_part.blob)
  # save file to db
- config = current_app.config
  upload_file = UploadFile(
      tenant_id=self.tenant_id,
-     storage_type=config['STORAGE_TYPE'],
+     storage_type=dify_config.STORAGE_TYPE,
      key=file_key,
      name=file_key,
      size=0,

  db.session.add(upload_file)
  db.session.commit()
- image_map[rel.target_part] = f"![image]({current_app.config.get('CONSOLE_API_URL')}/files/{upload_file.id}/image-preview)"
+ image_map[rel.target_part] = f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"

  return image_map



api/core/rag/index_processor/index_processor_base.py (+2, -3)

  from abc import ABC, abstractmethod
  from typing import Optional

- from flask import current_app
+ from configs import dify_config
  from core.model_manager import ModelInstance
  from core.rag.extractor.entity.extract_setting import ExtractSetting
  from core.rag.models.document import Document

  # The user-defined segmentation rule
  rules = processing_rule['rules']
  segmentation = rules["segmentation"]
- max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
+ max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
  if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
      raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")


