Преглед изворни кода

feat: add APIs for setting default datasource provider and updating provider name

tags/2.0.0-beta.1
Harry пре 3 месеци
родитељ
комит
af94602d37

+ 53
- 0
api/controllers/console/datasets/rag_pipeline/datasource_auth.py Прегледај датотеку

@@ -205,6 +205,7 @@ class DatasourceAuthListApi(Resource):
)
return {"result": jsonable_encoder(datasources)}, 200


class DatasourceAuthOauthCustomClient(Resource):
@setup_required
@login_required
@@ -227,6 +228,48 @@ class DatasourceAuthOauthCustomClient(Resource):
)
return {"result": "success"}, 200


class DatasourceAuthDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
credential_id=args["credential_id"],
)
return {"result": "success"}, 200

class DatasourceUpdateProviderNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider_id: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_provider_name(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
name=args["name"],
credential_id=args["credential_id"],
)
return {"result": "success"}, 200


api.add_resource(
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
@@ -254,3 +297,13 @@ api.add_resource(
DatasourceAuthOauthCustomClient,
"/auth/plugin/datasource/<path:provider_id>/custom-client",
)

api.add_resource(
DatasourceAuthDefaultApi,
"/auth/plugin/datasource/<path:provider_id>/default",
)

api.add_resource(
DatasourceUpdateProviderNameApi,
"/auth/plugin/datasource/<path:provider_id>/update-name",
)

+ 3
- 3
api/core/workflow/nodes/datasource/datasource_node.py Прегледај датотеку

@@ -127,13 +127,13 @@ class DatasourceNode(BaseNode):
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
credentials = datasource_provider_service.get_default_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
datasource_runtime.runtime.credentials = credentials
online_document_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.get_online_document_page_content(
user_id=self.user_id,
@@ -159,7 +159,7 @@ class DatasourceNode(BaseNode):
plugin_id=node_data.plugin_id,
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
datasource_runtime.runtime.credentials = credentials
online_drive_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.online_drive_download_file(
user_id=self.user_id,

+ 33
- 0
api/migrations/versions/2025_07_21_1523-74e5f667f4b7_add_pipeline_info_15.py Прегледај датотеку

@@ -0,0 +1,33 @@
"""add_pipeline_info_15

Revision ID: 74e5f667f4b7
Revises: d3c68680d3ba
Create Date: 2025-07-21 15:23:20.825747

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '74e5f667f4b7'
down_revision = 'd3c68680d3ba'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.drop_column('is_default')

# ### end Alembic commands ###

+ 1
- 0
api/models/oauth.py Прегледај датотеку

@@ -36,6 +36,7 @@ class DatasourceProvider(Base):
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))

created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)

+ 122
- 13
api/services/datasource_provider_service.py Прегледај датотеку

@@ -29,6 +29,96 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()

def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
"""
get default credentials
"""
with Session(db.engine) as session:
datasource_provider = (
session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
if not datasource_provider:
return {}
return datasource_provider.encrypted_credentials

def update_datasource_provider_name(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
):
"""
update datasource provider name
"""
with Session(db.engine) as session:
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if target_provider is None:
raise ValueError("provider not found")

if target_provider.name == name:
return

# check name is exist
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=name,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.count()
> 0
):
raise ValueError("name is already exists")

target_provider.name = name
session.commit()
return

def set_default_datasource_provider(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
):
"""
set default datasource provider
"""
with Session(db.engine) as session:
# get provider
target_provider = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
id=credential_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if target_provider is None:
raise ValueError("provider not found")

# clear default provider
session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=target_provider.provider,
plugin_id=target_provider.plugin_id,
is_default=True,
).update({"is_default": False})

# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}

def setup_oauth_custom_client_params(
self,
tenant_id: str,
@@ -41,10 +131,6 @@ class DatasourceProviderService:
"""
if client_params is None and enabled is None:
return
provider_controller = PluginDatasourceManager()
datasource_provider = provider_controller.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
with Session(db.engine) as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
@@ -252,7 +338,7 @@ class DatasourceProviderService:
)

provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type.value
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
@@ -310,7 +396,7 @@ class DatasourceProviderService:
)
if credential_valid:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY.value
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
@@ -329,7 +415,7 @@ class DatasourceProviderService:
else:
raise CredentialsValidateFailedError()

def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: str) -> list[str]:
"""
Extract secret input form variables.

@@ -339,7 +425,16 @@ class DatasourceProviderService:
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=provider_id
)
credential_form_schemas = datasource_provider.declaration.credentials_schema
credential_form_schemas = []
if credential_type == "api_key":
credential_form_schemas = datasource_provider.declaration.credentials_schema
elif credential_type == "oauth2":
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
credential_form_schemas = datasource_provider.declaration.oauth_schema.credentials_schema
else:
raise ValueError(f"Invalid credential type: {credential_type}")

secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
@@ -368,11 +463,20 @@ class DatasourceProviderService:
if not datasource_providers:
return []
copy_credentials_list = []
default_provider = (
db.session.query(DatasourceProvider.id)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
.first()
)
default_provider_id = default_provider.id if default_provider else None
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=datasource_provider.auth_type,
)

# Obfuscate provider credentials
@@ -387,6 +491,7 @@ class DatasourceProviderService:
"name": datasource_provider.name,
"avatar_url": datasource_provider.avatar_url,
"id": datasource_provider.id,
"is_default": default_provider_id and datasource_provider.id == default_provider_id,
}
)

@@ -469,7 +574,9 @@ class DatasourceProviderService:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=datasource_provider.auth_type,
)

# Obfuscate provider credentials
@@ -507,12 +614,14 @@ class DatasourceProviderService:
.first()
)

provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
else:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=datasource_provider.auth_type,
)
original_credentials = datasource_provider.encrypted_credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:

Loading…
Откажи
Сачувај