
refactor: Use typed SQLAlchemy base model and fix type errors (#19980)

Signed-off-by: -LAN- <laipz8200@outlook.com>
tags/1.4.1
-LAN- 5 months ago
commit 3196dc2d61

api/controllers/console/auth/login.py  +6 -6

  except AccountRegisterError as are:
      raise AccountInFreezeError()
  if account:
-     tenant = TenantService.get_join_tenants(account)
-     if not tenant:
+     tenants = TenantService.get_join_tenants(account)
+     if not tenants:
          workspaces = FeatureService.get_system_features().license.workspaces
          if not workspaces.is_available():
              raise WorkspacesLimitExceeded()
          if not FeatureService.get_system_features().is_allow_create_workspace:
              raise NotAllowedCreateWorkspace()
          else:
-             tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-             TenantService.create_tenant_member(tenant, account, role="owner")
-             account.current_tenant = tenant
-             tenant_was_created.send(tenant)
+             new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
+             TenantService.create_tenant_member(new_tenant, account, role="owner")
+             account.current_tenant = new_tenant
+             tenant_was_created.send(new_tenant)

  if account is None:
      try:

api/controllers/console/auth/oauth.py  +6 -6

  account = _get_account_by_openid_or_email(provider, user_info)

  if account:
-     tenant = TenantService.get_join_tenants(account)
-     if not tenant:
+     tenants = TenantService.get_join_tenants(account)
+     if not tenants:
          if not FeatureService.get_system_features().is_allow_create_workspace:
              raise WorkSpaceNotAllowedCreateError()
          else:
-             tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
-             TenantService.create_tenant_member(tenant, account, role="owner")
-             account.current_tenant = tenant
-             tenant_was_created.send(tenant)
+             new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
+             TenantService.create_tenant_member(new_tenant, account, role="owner")
+             account.current_tenant = new_tenant
+             tenant_was_created.send(new_tenant)

  if not account:
      if not FeatureService.get_system_features().is_allow_register:

api/controllers/console/datasets/datasets.py  +16 -3

          .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
          .count()
      )
-     document.completed_segments = completed_segments
-     document.total_segments = total_segments
-     documents_status.append(marshal(document, document_status_fields))
+     # Create a dictionary with document attributes and additional fields
+     document_dict = {
+         "id": document.id,
+         "indexing_status": document.indexing_status,
+         "processing_started_at": document.processing_started_at,
+         "parsing_completed_at": document.parsing_completed_at,
+         "cleaning_completed_at": document.cleaning_completed_at,
+         "splitting_completed_at": document.splitting_completed_at,
+         "completed_at": document.completed_at,
+         "paused_at": document.paused_at,
+         "error": document.error,
+         "stopped_at": document.stopped_at,
+         "completed_segments": completed_segments,
+         "total_segments": total_segments,
+     }
+     documents_status.append(marshal(document_dict, document_status_fields))
  data = {"data": documents_status}
  return data
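This change, repeated in the document and workspace controllers below, replaces ad-hoc attribute assignment on typed SQLAlchemy models with a plain dict passed to flask_restful's marshal(), which accepts dicts and objects alike. A minimal self-contained sketch of the pattern, with hypothetical field definitions rather than Dify's actual document_status_fields:

from flask_restful import fields, marshal

# Hypothetical field spec for illustration; the real document_status_fields
# in Dify declares more attributes.
document_status_fields = {
    "id": fields.String,
    "indexing_status": fields.String,
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
}

# Computed values live in a plain dict instead of being assigned to the
# ORM-mapped Document instance, so the type checker never sees unknown
# attributes on the model.
document_dict = {
    "id": "doc-1",
    "indexing_status": "completed",
    "completed_segments": 10,
    "total_segments": 12,
}

print(marshal(document_dict, document_status_fields))
# -> an OrderedDict containing exactly the declared fields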



api/controllers/console/datasets/datasets_document.py  +32 -10

          .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
          .count()
      )
-     document.completed_segments = completed_segments
-     document.total_segments = total_segments
-     if document.is_paused:
-         document.indexing_status = "paused"
-     documents_status.append(marshal(document, document_status_fields))
+     # Create a dictionary with document attributes and additional fields
+     document_dict = {
+         "id": document.id,
+         "indexing_status": "paused" if document.is_paused else document.indexing_status,
+         "processing_started_at": document.processing_started_at,
+         "parsing_completed_at": document.parsing_completed_at,
+         "cleaning_completed_at": document.cleaning_completed_at,
+         "splitting_completed_at": document.splitting_completed_at,
+         "completed_at": document.completed_at,
+         "paused_at": document.paused_at,
+         "error": document.error,
+         "stopped_at": document.stopped_at,
+         "completed_segments": completed_segments,
+         "total_segments": total_segments,
+     }
+     documents_status.append(marshal(document_dict, document_status_fields))
  data = {"data": documents_status}
  return data

          .count()
      )

-     document.completed_segments = completed_segments
-     document.total_segments = total_segments
-     if document.is_paused:
-         document.indexing_status = "paused"
-     return marshal(document, document_status_fields)
+     # Create a dictionary with document attributes and additional fields
+     document_dict = {
+         "id": document.id,
+         "indexing_status": "paused" if document.is_paused else document.indexing_status,
+         "processing_started_at": document.processing_started_at,
+         "parsing_completed_at": document.parsing_completed_at,
+         "cleaning_completed_at": document.cleaning_completed_at,
+         "splitting_completed_at": document.splitting_completed_at,
+         "completed_at": document.completed_at,
+         "paused_at": document.paused_at,
+         "error": document.error,
+         "stopped_at": document.stopped_at,
+         "completed_segments": completed_segments,
+         "total_segments": total_segments,
+     }
+     return marshal(document_dict, document_status_fields)


  class DocumentDetailApi(DocumentResource):

api/controllers/console/workspace/workspace.py  +15 -7

  @account_initialization_required
  def get(self):
      tenants = TenantService.get_join_tenants(current_user)
+     tenant_dicts = []

      for tenant in tenants:
          features = FeatureService.get_features(tenant.id)
-         if features.billing.enabled:
-             tenant.plan = features.billing.subscription.plan
-         else:
-             tenant.plan = "sandbox"
-         if tenant.id == current_user.current_tenant_id:
-             tenant.current = True  # Set current=True for current tenant
-     return {"workspaces": marshal(tenants, tenants_fields)}, 200
+
+         # Create a dictionary with tenant attributes
+         tenant_dict = {
+             "id": tenant.id,
+             "name": tenant.name,
+             "status": tenant.status,
+             "created_at": tenant.created_at,
+             "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox",
+             "current": tenant.id == current_user.current_tenant_id,
+         }
+
+         tenant_dicts.append(tenant_dict)
+
+     return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200


  class WorkspaceListApi(Resource):

api/controllers/files/upload.py  +18 -3

          extension = guess_extension(tool_file.mimetype) or ".bin"
          preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
-         tool_file.mime_type = mimetype
-         tool_file.extension = extension
-         tool_file.preview_url = preview_url
+
+         # Create a dictionary with all the necessary attributes
+         result = {
+             "id": tool_file.id,
+             "user_id": tool_file.user_id,
+             "tenant_id": tool_file.tenant_id,
+             "conversation_id": tool_file.conversation_id,
+             "file_key": tool_file.file_key,
+             "mimetype": tool_file.mimetype,
+             "original_url": tool_file.original_url,
+             "name": tool_file.name,
+             "size": tool_file.size,
+             "mime_type": mimetype,
+             "extension": extension,
+             "preview_url": preview_url,
+         }
+
+         return result, 201
      except services.errors.file.FileTooLargeError as file_too_large_error:
          raise FileTooLargeError(file_too_large_error.description)
      except services.errors.file.UnsupportedFileTypeError:

api/controllers/service_api/dataset/document.py  +16 -5

          .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
          .count()
      )
-     document.completed_segments = completed_segments
-     document.total_segments = total_segments
-     if document.is_paused:
-         document.indexing_status = "paused"
-     documents_status.append(marshal(document, document_status_fields))
+     # Create a dictionary with document attributes and additional fields
+     document_dict = {
+         "id": document.id,
+         "indexing_status": "paused" if document.is_paused else document.indexing_status,
+         "processing_started_at": document.processing_started_at,
+         "parsing_completed_at": document.parsing_completed_at,
+         "cleaning_completed_at": document.cleaning_completed_at,
+         "splitting_completed_at": document.splitting_completed_at,
+         "completed_at": document.completed_at,
+         "paused_at": document.paused_at,
+         "error": document.error,
+         "stopped_at": document.stopped_at,
+         "completed_segments": completed_segments,
+         "total_segments": total_segments,
+     }
+     documents_status.append(marshal(document_dict, document_status_fields))
  data = {"data": documents_status}
  return data



api/core/rag/datasource/retrieval_service.py  +23 -1

                  record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks")  # type: ignore
                  record["score"] = segment_child_map[record["segment"].id]["max_score"]

-             return [RetrievalSegments(**record) for record in records]
+             result = []
+             for record in records:
+                 # Extract segment
+                 segment = record["segment"]
+
+                 # Extract child_chunks, ensuring it's a list or None
+                 child_chunks = record.get("child_chunks")
+                 if not isinstance(child_chunks, list):
+                     child_chunks = None
+
+                 # Extract score, ensuring it's a float or None
+                 score_value = record.get("score")
+                 score = (
+                     float(score_value)
+                     if score_value is not None and isinstance(score_value, int | float | str)
+                     else None
+                 )
+
+                 # Create RetrievalSegments object
+                 retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
+                 result.append(retrieval_segment)
+
+             return result
          except Exception as e:
              db.session.rollback()
              raise e
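Expanding the one-line comprehension makes the runtime narrowing of each loosely typed dict value explicit before the typed constructor runs. A standalone sketch of the same narrowing idea, using a hypothetical dataclass in place of RetrievalSegments:

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class SegmentResult:  # hypothetical stand-in for RetrievalSegments
    segment_id: str
    child_chunks: Optional[list] = None
    score: Optional[float] = None

def to_result(record: dict[str, Any]) -> SegmentResult:
    child_chunks = record.get("child_chunks")
    if not isinstance(child_chunks, list):
        child_chunks = None  # anything non-list is treated as missing
    raw_score = record.get("score")
    # Only coerce values that are safely convertible to float.
    score = float(raw_score) if isinstance(raw_score, (int, float)) else None
    return SegmentResult(record["segment_id"], child_chunks, score)

print(to_result({"segment_id": "seg-1", "score": 0.87, "child_chunks": "oops"}))
# SegmentResult(segment_id='seg-1', child_chunks=None, score=0.87)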

api/core/tools/tool_manager.py  +3 -3

                  yield provider

          except Exception:
-             logger.exception(f"load builtin provider {provider}")
+             logger.exception(f"load builtin provider {provider_path}")
              continue
      # set builtin providers loaded
      cls._builtin_providers_loaded = True

      )

      workflow_provider_controllers: list[WorkflowToolProviderController] = []
-     for provider in workflow_providers:
+     for workflow_provider in workflow_providers:
          try:
              workflow_provider_controllers.append(
-                 ToolTransformService.workflow_provider_to_controller(db_provider=provider)
+                 ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
              )
          except Exception:
              # app has been deleted

api/core/tools/workflow_as_tool/tool.py  +4 -16

  import json
  import logging
  from collections.abc import Generator
- from typing import Any, Optional, Union, cast
+ from typing import Any, Optional, cast
+
+ from flask_login import current_user

  from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
  from core.tools.__base.tool import Tool

          result = generator.generate(
              app_model=app,
              workflow=workflow,
-             user=self._get_user(user_id),
+             user=cast("Account | EndUser", current_user),
              args={"inputs": tool_parameters, "files": files},
              invoke_from=self.runtime.invoke_from,
              streaming=False,

          yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
          yield self.create_json_message(outputs)

-     def _get_user(self, user_id: str) -> Union[EndUser, Account]:
-         """
-         get the user by user id
-         """
-
-         user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
-         if not user:
-             user = db.session.query(Account).filter(Account.id == user_id).first()
-
-         if not user:
-             raise ValueError("user not found")
-
-         return user
-
      def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool":
          """
          fork a new tool with metadata

api/extensions/ext_login.py  +33 -16

  import flask_login  # type: ignore
  from flask import Response, request
  from flask_login import user_loaded_from_request, user_logged_in
- from werkzeug.exceptions import Unauthorized
+ from werkzeug.exceptions import NotFound, Unauthorized

  import contexts
  from dify_app import DifyApp
+ from extensions.ext_database import db
  from libs.passport import PassportService
+ from models.account import Account
+ from models.model import EndUser
  from services.account_service import AccountService

  login_manager = flask_login.LoginManager()

  @login_manager.request_loader
  def load_user_from_request(request_from_flask_login):
      """Load user based on the request."""
-     if request.blueprint not in {"console", "inner_api"}:
-         return None
-     # Check if the user_id contains a dot, indicating the old format
      auth_header = request.headers.get("Authorization", "")
-     if not auth_header:
-         auth_token = request.args.get("_token")
-         if not auth_token:
-             raise Unauthorized("Invalid Authorization token.")
-     else:
+     auth_token: str | None = None
+     if auth_header:
          if " " not in auth_header:
              raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
-         auth_scheme, auth_token = auth_header.split(None, 1)
+         auth_scheme, auth_token = auth_header.split(maxsplit=1)
          auth_scheme = auth_scheme.lower()
          if auth_scheme != "bearer":
              raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
+     else:
+         auth_token = request.args.get("_token")

-     decoded = PassportService().verify(auth_token)
-     user_id = decoded.get("user_id")
+     if request.blueprint in {"console", "inner_api"}:
+         if not auth_token:
+             raise Unauthorized("Invalid Authorization token.")
+         decoded = PassportService().verify(auth_token)
+         user_id = decoded.get("user_id")
+         if not user_id:
+             raise Unauthorized("Invalid Authorization token.")

-     logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
-     return logged_in_account
+         logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
+         return logged_in_account
+     elif request.blueprint == "web":
+         decoded = PassportService().verify(auth_token)
+         end_user_id = decoded.get("end_user_id")
+         if not end_user_id:
+             raise Unauthorized("Invalid Authorization token.")
+         end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
+         if not end_user:
+             raise NotFound("End user not found.")
+         return end_user


  @user_logged_in.connect
  @user_loaded_from_request.connect
  def on_user_logged_in(_sender, user):
-     """Called when a user logged in."""
-     if user:
+     """Called when a user logged in.
+
+     Note: AccountService.load_logged_in_account will populate user.current_tenant_id
+     through the load_user method, which calls account.set_tenant_id().
+     """
+     if user and isinstance(user, Account) and user.current_tenant_id:
          contexts.tenant_id.set(user.current_tenant_id)
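The reworked loader extracts the token once and then branches on the blueprint, returning an Account for console/inner_api requests and an EndUser for web requests. A minimal self-contained sketch of flask-login's request_loader pattern used here (hypothetical User class and token handling, not Dify's PassportService):

import flask_login
from flask import Flask, request
from werkzeug.exceptions import Unauthorized

app = Flask(__name__)
login_manager = flask_login.LoginManager()
login_manager.init_app(app)

class User(flask_login.UserMixin):  # hypothetical user type
    def __init__(self, user_id: str):
        self.id = user_id

@login_manager.request_loader
def load_user_from_request(_request):
    auth_header = request.headers.get("Authorization", "")
    auth_token: str | None = None
    if auth_header:
        if " " not in auth_header:
            raise Unauthorized("Expected 'Bearer <token>'.")
        scheme, auth_token = auth_header.split(maxsplit=1)
        if scheme.lower() != "bearer":
            raise Unauthorized("Expected 'Bearer <token>'.")
    else:
        auth_token = request.args.get("_token")  # query-string fallback
    if not auth_token:
        raise Unauthorized("Invalid Authorization token.")
    # Real code would verify a signed token here and look the user up.
    return User(auth_token)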





api/models/account.py  +80 -77

  import enum
  import json
- from typing import cast
+ from typing import Optional, cast

  from flask_login import UserMixin  # type: ignore
  from sqlalchemy import func
- from sqlalchemy.orm import Mapped, mapped_column
+ from sqlalchemy.orm import Mapped, mapped_column, reconstructor

  from models.base import Base

  from .types import StringUUID


+ class TenantAccountRole(enum.StrEnum):
+     OWNER = "owner"
+     ADMIN = "admin"
+     EDITOR = "editor"
+     NORMAL = "normal"
+     DATASET_OPERATOR = "dataset_operator"
+
+     @staticmethod
+     def is_valid_role(role: str) -> bool:
+         if not role:
+             return False
+         return role in {
+             TenantAccountRole.OWNER,
+             TenantAccountRole.ADMIN,
+             TenantAccountRole.EDITOR,
+             TenantAccountRole.NORMAL,
+             TenantAccountRole.DATASET_OPERATOR,
+         }
+
+     @staticmethod
+     def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool:
+         if not role:
+             return False
+         return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
+
+     @staticmethod
+     def is_admin_role(role: Optional["TenantAccountRole"]) -> bool:
+         if not role:
+             return False
+         return role == TenantAccountRole.ADMIN
+
+     @staticmethod
+     def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool:
+         if not role:
+             return False
+         return role in {
+             TenantAccountRole.ADMIN,
+             TenantAccountRole.EDITOR,
+             TenantAccountRole.NORMAL,
+             TenantAccountRole.DATASET_OPERATOR,
+         }
+
+     @staticmethod
+     def is_editing_role(role: Optional["TenantAccountRole"]) -> bool:
+         if not role:
+             return False
+         return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
+
+     @staticmethod
+     def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool:
+         if not role:
+             return False
+         return role in {
+             TenantAccountRole.OWNER,
+             TenantAccountRole.ADMIN,
+             TenantAccountRole.EDITOR,
+             TenantAccountRole.DATASET_OPERATOR,
+         }
+
+
  class AccountStatus(enum.StrEnum):
      PENDING = "pending"
      UNINITIALIZED = "uninitialized"

      created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
      updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

+     @reconstructor
+     def init_on_load(self):
+         self.role: Optional[TenantAccountRole] = None
+         self._current_tenant: Optional[Tenant] = None
+
      @property
      def is_password_set(self):
          return self.password is not None

      @property
      def current_tenant(self):
-         return self._current_tenant  # type: ignore
+         return self._current_tenant

      @current_tenant.setter
-     def current_tenant(self, value: "Tenant"):
-         tenant = value
+     def current_tenant(self, tenant: "Tenant"):
          ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
          if ta:
-             tenant.current_role = ta.role
-         else:
-             tenant = None  # type: ignore
-
-         self._current_tenant = tenant
+             self.role = TenantAccountRole(ta.role)
+             self._current_tenant = tenant
+             return
+         self._current_tenant = None

      @property
      def current_tenant_id(self) -> str | None:
          return

      tenant, join = tenant_account_join
-     tenant.current_role = join.role
+     self.role = join.role
      self._current_tenant = tenant

      @property
      def current_role(self):
-         return self._current_tenant.current_role
+         return self.role

      def get_status(self) -> AccountStatus:
          status_str = self.status

      # check current_user.current_tenant.current_role in ['admin', 'owner']
      @property
      def is_admin_or_owner(self):
-         return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
+         return TenantAccountRole.is_privileged_role(self.role)

      @property
      def is_admin(self):
-         return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
+         return TenantAccountRole.is_admin_role(self.role)

      @property
      def is_editor(self):
-         return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
+         return TenantAccountRole.is_editing_role(self.role)

      @property
      def is_dataset_editor(self):
-         return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role)
+         return TenantAccountRole.is_dataset_edit_role(self.role)

      @property
      def is_dataset_operator(self):
-         return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
+         return self.role == TenantAccountRole.DATASET_OPERATOR


  class TenantStatus(enum.StrEnum):
      ARCHIVE = "archive"


- class TenantAccountRole(enum.StrEnum):
-     OWNER = "owner"
-     ADMIN = "admin"
-     EDITOR = "editor"
-     NORMAL = "normal"
-     DATASET_OPERATOR = "dataset_operator"
-
-     @staticmethod
-     def is_valid_role(role: str) -> bool:
-         if not role:
-             return False
-         return role in {
-             TenantAccountRole.OWNER,
-             TenantAccountRole.ADMIN,
-             TenantAccountRole.EDITOR,
-             TenantAccountRole.NORMAL,
-             TenantAccountRole.DATASET_OPERATOR,
-         }
-
-     @staticmethod
-     def is_privileged_role(role: str) -> bool:
-         if not role:
-             return False
-         return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
-
-     @staticmethod
-     def is_admin_role(role: str) -> bool:
-         if not role:
-             return False
-         return role == TenantAccountRole.ADMIN
-
-     @staticmethod
-     def is_non_owner_role(role: str) -> bool:
-         if not role:
-             return False
-         return role in {
-             TenantAccountRole.ADMIN,
-             TenantAccountRole.EDITOR,
-             TenantAccountRole.NORMAL,
-             TenantAccountRole.DATASET_OPERATOR,
-         }
-
-     @staticmethod
-     def is_editing_role(role: str) -> bool:
-         if not role:
-             return False
-         return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
-
-     @staticmethod
-     def is_dataset_edit_role(role: str) -> bool:
-         if not role:
-             return False
-         return role in {
-             TenantAccountRole.OWNER,
-             TenantAccountRole.ADMIN,
-             TenantAccountRole.EDITOR,
-             TenantAccountRole.DATASET_OPERATOR,
-         }


  class Tenant(Base):
      __tablename__ = "tenants"
      __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
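The decisive detail in account.py is @reconstructor: SQLAlchemy does not call __init__ when it materializes a row, so plain Python attributes such as role and _current_tenant must be created in a reconstructor to reliably exist on loaded instances. A runnable sketch under that assumption, with a deliberately minimal, hypothetical table:

from typing import Optional

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, reconstructor

class Base(DeclarativeBase):
    pass

class Account(Base):  # hypothetical model, not Dify's schema
    __tablename__ = "accounts"
    id: Mapped[str] = mapped_column(String, primary_key=True)

    @reconstructor
    def init_on_load(self) -> None:
        # Runs instead of __init__ when the ORM loads a row from the
        # database, so this non-column attribute always exists.
        self.role: Optional[str] = None

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Account(id="a1"))
    session.commit()

# A fresh session forces a real load, which fires the reconstructor.
with Session(engine) as session:
    loaded = session.scalars(select(Account)).one()
    print(loaded.role)  # None, courtesy of init_on_load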

api/models/base.py  +4 -2

- from sqlalchemy.orm import declarative_base
+ from sqlalchemy.orm import DeclarativeBase

  from models.engine import metadata

- Base = declarative_base(metadata=metadata)
+
+ class Base(DeclarativeBase):
+     metadata = metadata
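Replacing the declarative_base() factory with a DeclarativeBase subclass is the core of this commit: type checkers get a real class to analyze, which is what lets models use Mapped[...] columns and drop type: ignore comments elsewhere in the diff. A small sketch of the payoff (hypothetical Tenant model, SQLAlchemy 2.0 style):

from sqlalchemy import MetaData, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

metadata = MetaData()  # stands in for models.engine.metadata

class Base(DeclarativeBase):
    # Attach the shared MetaData, as the diff above does.
    metadata = metadata

class Tenant(Base):  # hypothetical model for illustration
    __tablename__ = "tenants"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False)

# Type checkers now see Tenant as an ordinary class: the constructor is
# typed and tenant.name is known to be a str, neither of which the
# dynamically created declarative_base() class could express.
tenant = Tenant(id="t1", name="Acme")
print(tenant.name.upper())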

api/models/tools.py  +0 -4

          db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
      )

-     @property
-     def schema_type(self) -> ApiProviderSchemaType:
-         return ApiProviderSchemaType.value_of(self.schema_type_str)
-
      @property
      def user(self) -> Account | None:
          return db.session.query(Account).filter(Account.id == self.user_id).first()

api/models/workflow.py  +2 -2

  from collections.abc import Mapping, Sequence
  from datetime import UTC, datetime
  from enum import Enum, StrEnum
- from typing import TYPE_CHECKING, Any, Optional, Self, Union
+ from typing import TYPE_CHECKING, Any, Optional, Union
  from uuid import uuid4

  from core.variables import utils as variable_utils

          conversation_variables: Sequence[Variable],
          marked_name: str = "",
          marked_comment: str = "",
-     ) -> Self:
+     ) -> "Workflow":
          workflow = Workflow()
          workflow.id = str(uuid4())
          workflow.tenant_id = tenant_id

api/services/vector_service.py  +8 -7

  ):
      documents: list[Document] = []

-     document: Document | None = None
      for segment in segments:
          if doc_form == IndexType.PARENT_CHILD_INDEX:
-             document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
-             if not document:
+             dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
+             if not dataset_document:
                  _logger.warning(
                      "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
                      segment.document_id,

              # get the process rule
              processing_rule = (
                  db.session.query(DatasetProcessRule)
-                 .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+                 .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                  .first()
              )
              if not processing_rule:

                  )
              else:
                  raise ValueError("The knowledge base index technique is not high quality!")
-             cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
+             cls.generate_child_chunks(
+                 segment, dataset_document, dataset, embedding_model_instance, processing_rule, False
+             )
          else:
-             document = Document(
+             rag_document = Document(
                  page_content=segment.content,
                  metadata={
                      "doc_id": segment.index_node_id,
                      "dataset_id": segment.dataset_id,
                  },
              )
-             documents.append(document)
+             documents.append(rag_document)
      if len(documents) > 0:
          index_processor = IndexProcessorFactory(doc_form).init_index_processor()
          index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

api/services/workflow_service.py  +3 -3

          raise DraftWorkflowDeletionError("Cannot delete draft workflow versions")

      # Check if this workflow is currently referenced by an app
-     stmt = select(App).where(App.workflow_id == workflow_id)
-     app = session.scalar(stmt)
+     app_stmt = select(App).where(App.workflow_id == workflow_id)
+     app = session.scalar(app_stmt)
      if app:
          # Cannot delete a workflow that's currently in use by an app
-         raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
+         raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'")

      # Don't use workflow.tool_published as it's not accurate for specific workflow versions
      # Check if there's a tool provider using this specific workflow version

api/tasks/add_document_to_index_task.py  +1 -1

          logging.exception("add document to index failed")
          dataset_document.enabled = False
          dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
-         dataset_document.status = "error"
+         dataset_document.indexing_status = "error"
          dataset_document.error = str(e)
          db.session.commit()
      finally:

api/tasks/remove_app_and_related_data_task.py  +1 -1

  def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
      # Get app's owner
      with Session(db.engine, expire_on_commit=False) as session:
-         stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id)
+         stmt = select(Account).where(Account.id == App.created_by).where(App.id == app_id)
          user = session.scalar(stmt)

          if user is None:

api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py  +1 -1

      # needs to patch those methods to avoid database access.
      monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
      monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
-     monkeypatch.setattr(tool, "_get_user", lambda *args, **kwargs: None)

      # replace `WorkflowAppGenerator.generate` 's return value.
      monkeypatch.setattr(
          "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
          lambda *args, **kwargs: {"data": {"error": "oops"}},
      )
+     monkeypatch.setattr("flask_login.current_user", lambda *args, **kwargs: None)

      with pytest.raises(ToolInvokeError) as exc_info:
          # WorkflowTool always returns a generator, so we need to iterate to
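With _get_user gone, the test patches flask_login.current_user directly. A standalone sketch of that monkeypatch pattern (hypothetical describe_user function, not Dify's tool code):

import flask_login

def describe_user() -> str:
    # Re-reads the module attribute on every call, so a monkeypatched
    # value is picked up without an app or request context.
    user = flask_login.current_user
    return "anonymous" if user is None else str(user)

def test_describe_user(monkeypatch):
    # Replace the LocalProxy with a plain value for the duration of the test.
    monkeypatch.setattr("flask_login.current_user", None)
    assert describe_user() == "anonymous"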
