Co-authored-by: -LAN- <laipz8200@outlook.com>tags/0.7.2
| @@ -74,7 +74,6 @@ exclude = [ | |||
| "controllers/**/*.py", | |||
| "models/**/*.py", | |||
| "migrations/**/*", | |||
| "services/**/*.py", | |||
| ] | |||
| [tool.pytest_env] | |||
| @@ -1,3 +1,3 @@ | |||
| from . import errors | |||
| __all__ = ['errors'] | |||
| __all__ = ["errors"] | |||
| @@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task | |||
| class AccountService: | |||
| reset_password_rate_limiter = RateLimiter( | |||
| prefix="reset_password_rate_limit", | |||
| max_attempts=5, | |||
| time_window=60 * 60 | |||
| ) | |||
| reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) | |||
| @staticmethod | |||
| def load_user(user_id: str) -> None | Account: | |||
| @@ -55,12 +50,15 @@ class AccountService: | |||
| if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: | |||
| raise Unauthorized("Account is banned or closed.") | |||
| current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() | |||
| current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( | |||
| account_id=account.id, current=True | |||
| ).first() | |||
| if current_tenant: | |||
| account.current_tenant_id = current_tenant.tenant_id | |||
| else: | |||
| available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ | |||
| .order_by(TenantAccountJoin.id.asc()).first() | |||
| available_ta = ( | |||
| TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() | |||
| ) | |||
| if not available_ta: | |||
| return None | |||
| @@ -74,14 +72,13 @@ class AccountService: | |||
| return account | |||
| @staticmethod | |||
| def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): | |||
| payload = { | |||
| "user_id": account.id, | |||
| "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, | |||
| "iss": dify_config.EDITION, | |||
| "sub": 'Console API Passport', | |||
| "sub": "Console API Passport", | |||
| } | |||
| token = PassportService().issue(payload) | |||
| @@ -93,10 +90,10 @@ class AccountService: | |||
| account = Account.query.filter_by(email=email).first() | |||
| if not account: | |||
| raise AccountLoginError('Invalid email or password.') | |||
| raise AccountLoginError("Invalid email or password.") | |||
| if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: | |||
| raise AccountLoginError('Account is banned or closed.') | |||
| raise AccountLoginError("Account is banned or closed.") | |||
| if account.status == AccountStatus.PENDING.value: | |||
| account.status = AccountStatus.ACTIVE.value | |||
| @@ -104,7 +101,7 @@ class AccountService: | |||
| db.session.commit() | |||
| if account.password is None or not compare_password(password, account.password, account.password_salt): | |||
| raise AccountLoginError('Invalid email or password.') | |||
| raise AccountLoginError("Invalid email or password.") | |||
| return account | |||
| @staticmethod | |||
| @@ -129,11 +126,9 @@ class AccountService: | |||
| return account | |||
| @staticmethod | |||
| def create_account(email: str, | |||
| name: str, | |||
| interface_language: str, | |||
| password: Optional[str] = None, | |||
| interface_theme: str = 'light') -> Account: | |||
| def create_account( | |||
| email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" | |||
| ) -> Account: | |||
| """create account""" | |||
| account = Account() | |||
| account.email = email | |||
| @@ -155,7 +150,7 @@ class AccountService: | |||
| account.interface_theme = interface_theme | |||
| # Set timezone based on language | |||
| account.timezone = language_timezone_mapping.get(interface_language, 'UTC') | |||
| account.timezone = language_timezone_mapping.get(interface_language, "UTC") | |||
| db.session.add(account) | |||
| db.session.commit() | |||
| @@ -166,8 +161,9 @@ class AccountService: | |||
| """Link account integrate""" | |||
| try: | |||
| # Query whether there is an existing binding record for the same provider | |||
| account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, | |||
| provider=provider).first() | |||
| account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( | |||
| account_id=account.id, provider=provider | |||
| ).first() | |||
| if account_integrate: | |||
| # If it exists, update the record | |||
| @@ -176,15 +172,16 @@ class AccountService: | |||
| account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| else: | |||
| # If it does not exist, create a new record | |||
| account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, | |||
| encrypted_token="") | |||
| account_integrate = AccountIntegrate( | |||
| account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" | |||
| ) | |||
| db.session.add(account_integrate) | |||
| db.session.commit() | |||
| logging.info(f'Account {account.id} linked {provider} account {open_id}.') | |||
| logging.info(f"Account {account.id} linked {provider} account {open_id}.") | |||
| except Exception as e: | |||
| logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') | |||
| raise LinkAccountIntegrateError('Failed to link account.') from e | |||
| logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") | |||
| raise LinkAccountIntegrateError("Failed to link account.") from e | |||
| @staticmethod | |||
| def close_account(account: Account) -> None: | |||
| @@ -218,7 +215,7 @@ class AccountService: | |||
| AccountService.update_last_login(account, ip_address=ip_address) | |||
| exp = timedelta(days=30) | |||
| token = AccountService.get_account_jwt_token(account, exp=exp) | |||
| redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) | |||
| redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) | |||
| return token | |||
| @staticmethod | |||
| @@ -236,22 +233,18 @@ class AccountService: | |||
| if cls.reset_password_rate_limiter.is_rate_limited(account.email): | |||
| raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") | |||
| token = TokenManager.generate_token(account, 'reset_password') | |||
| send_reset_password_mail_task.delay( | |||
| language=account.interface_language, | |||
| to=account.email, | |||
| token=token | |||
| ) | |||
| token = TokenManager.generate_token(account, "reset_password") | |||
| send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) | |||
| cls.reset_password_rate_limiter.increment_rate_limit(account.email) | |||
| return token | |||
| @classmethod | |||
| def revoke_reset_password_token(cls, token: str): | |||
| TokenManager.revoke_token(token, 'reset_password') | |||
| TokenManager.revoke_token(token, "reset_password") | |||
| @classmethod | |||
| def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: | |||
| return TokenManager.get_token_data(token, 'reset_password') | |||
| return TokenManager.get_token_data(token, "reset_password") | |||
| def _get_login_cache_key(*, account_id: str, token: str): | |||
| @@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str): | |||
| class TenantService: | |||
| @staticmethod | |||
| def create_tenant(name: str) -> Tenant: | |||
| """Create tenant""" | |||
| @@ -275,31 +267,28 @@ class TenantService: | |||
| @staticmethod | |||
| def create_owner_tenant_if_not_exist(account: Account): | |||
| """Create owner tenant if not exist""" | |||
| available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ | |||
| .order_by(TenantAccountJoin.id.asc()).first() | |||
| available_ta = ( | |||
| TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() | |||
| ) | |||
| if available_ta: | |||
| return | |||
| tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||
| TenantService.create_tenant_member(tenant, account, role="owner") | |||
| account.current_tenant = tenant | |||
| db.session.commit() | |||
| tenant_was_created.send(tenant) | |||
| @staticmethod | |||
| def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: | |||
| def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: | |||
| """Create tenant member""" | |||
| if role == TenantAccountJoinRole.OWNER.value: | |||
| if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): | |||
| logging.error(f'Tenant {tenant.id} has already an owner.') | |||
| raise Exception('Tenant already has an owner.') | |||
| logging.error(f"Tenant {tenant.id} has already an owner.") | |||
| raise Exception("Tenant already has an owner.") | |||
| ta = TenantAccountJoin( | |||
| tenant_id=tenant.id, | |||
| account_id=account.id, | |||
| role=role | |||
| ) | |||
| ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) | |||
| db.session.add(ta) | |||
| db.session.commit() | |||
| return ta | |||
| @@ -307,9 +296,12 @@ class TenantService: | |||
| @staticmethod | |||
| def get_join_tenants(account: Account) -> list[Tenant]: | |||
| """Get account join tenants""" | |||
| return db.session.query(Tenant).join( | |||
| TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id | |||
| ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() | |||
| return ( | |||
| db.session.query(Tenant) | |||
| .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) | |||
| .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) | |||
| .all() | |||
| ) | |||
| @staticmethod | |||
| def get_current_tenant_by_account(account: Account): | |||
| @@ -333,16 +325,23 @@ class TenantService: | |||
| if tenant_id is None: | |||
| raise ValueError("Tenant ID must be provided.") | |||
| tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( | |||
| TenantAccountJoin.account_id == account.id, | |||
| TenantAccountJoin.tenant_id == tenant_id, | |||
| Tenant.status == TenantStatus.NORMAL, | |||
| ).first() | |||
| tenant_account_join = ( | |||
| db.session.query(TenantAccountJoin) | |||
| .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) | |||
| .filter( | |||
| TenantAccountJoin.account_id == account.id, | |||
| TenantAccountJoin.tenant_id == tenant_id, | |||
| Tenant.status == TenantStatus.NORMAL, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not tenant_account_join: | |||
| raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") | |||
| else: | |||
| TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) | |||
| TenantAccountJoin.query.filter( | |||
| TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id | |||
| ).update({"current": False}) | |||
| tenant_account_join.current = True | |||
| # Set the current tenant for the account | |||
| account.current_tenant_id = tenant_account_join.tenant_id | |||
| @@ -354,9 +353,7 @@ class TenantService: | |||
| query = ( | |||
| db.session.query(Account, TenantAccountJoin.role) | |||
| .select_from(Account) | |||
| .join( | |||
| TenantAccountJoin, Account.id == TenantAccountJoin.account_id | |||
| ) | |||
| .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) | |||
| .filter(TenantAccountJoin.tenant_id == tenant.id) | |||
| ) | |||
| @@ -375,11 +372,9 @@ class TenantService: | |||
| query = ( | |||
| db.session.query(Account, TenantAccountJoin.role) | |||
| .select_from(Account) | |||
| .join( | |||
| TenantAccountJoin, Account.id == TenantAccountJoin.account_id | |||
| ) | |||
| .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) | |||
| .filter(TenantAccountJoin.tenant_id == tenant.id) | |||
| .filter(TenantAccountJoin.role == 'dataset_operator') | |||
| .filter(TenantAccountJoin.role == "dataset_operator") | |||
| ) | |||
| # Initialize an empty list to store the updated accounts | |||
| @@ -395,20 +390,25 @@ class TenantService: | |||
| def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool: | |||
| """Check if user has any of the given roles for a tenant""" | |||
| if not all(isinstance(role, TenantAccountJoinRole) for role in roles): | |||
| raise ValueError('all roles must be TenantAccountJoinRole') | |||
| raise ValueError("all roles must be TenantAccountJoinRole") | |||
| return db.session.query(TenantAccountJoin).filter( | |||
| TenantAccountJoin.tenant_id == tenant.id, | |||
| TenantAccountJoin.role.in_([role.value for role in roles]) | |||
| ).first() is not None | |||
| return ( | |||
| db.session.query(TenantAccountJoin) | |||
| .filter( | |||
| TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) | |||
| ) | |||
| .first() | |||
| is not None | |||
| ) | |||
| @staticmethod | |||
| def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: | |||
| """Get the role of the current account for a given tenant""" | |||
| join = db.session.query(TenantAccountJoin).filter( | |||
| TenantAccountJoin.tenant_id == tenant.id, | |||
| TenantAccountJoin.account_id == account.id | |||
| ).first() | |||
| join = ( | |||
| db.session.query(TenantAccountJoin) | |||
| .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) | |||
| .first() | |||
| ) | |||
| return join.role if join else None | |||
| @staticmethod | |||
| @@ -420,29 +420,26 @@ class TenantService: | |||
| def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: | |||
| """Check member permission""" | |||
| perms = { | |||
| 'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], | |||
| 'remove': [TenantAccountRole.OWNER], | |||
| 'update': [TenantAccountRole.OWNER] | |||
| "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], | |||
| "remove": [TenantAccountRole.OWNER], | |||
| "update": [TenantAccountRole.OWNER], | |||
| } | |||
| if action not in ['add', 'remove', 'update']: | |||
| if action not in ["add", "remove", "update"]: | |||
| raise InvalidActionError("Invalid action.") | |||
| if member: | |||
| if operator.id == member.id: | |||
| raise CannotOperateSelfError("Cannot operate self.") | |||
| ta_operator = TenantAccountJoin.query.filter_by( | |||
| tenant_id=tenant.id, | |||
| account_id=operator.id | |||
| ).first() | |||
| ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() | |||
| if not ta_operator or ta_operator.role not in perms[action]: | |||
| raise NoPermissionError(f'No permission to {action} member.') | |||
| raise NoPermissionError(f"No permission to {action} member.") | |||
| @staticmethod | |||
| def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: | |||
| """Remove member from tenant""" | |||
| if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): | |||
| if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"): | |||
| raise CannotOperateSelfError("Cannot operate self.") | |||
| ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() | |||
| @@ -455,23 +452,17 @@ class TenantService: | |||
| @staticmethod | |||
| def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: | |||
| """Update member role""" | |||
| TenantService.check_member_permission(tenant, operator, member, 'update') | |||
| TenantService.check_member_permission(tenant, operator, member, "update") | |||
| target_member_join = TenantAccountJoin.query.filter_by( | |||
| tenant_id=tenant.id, | |||
| account_id=member.id | |||
| ).first() | |||
| target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() | |||
| if target_member_join.role == new_role: | |||
| raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") | |||
| if new_role == 'owner': | |||
| if new_role == "owner": | |||
| # Find the current owner and change their role to 'admin' | |||
| current_owner_join = TenantAccountJoin.query.filter_by( | |||
| tenant_id=tenant.id, | |||
| role='owner' | |||
| ).first() | |||
| current_owner_join.role = 'admin' | |||
| current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() | |||
| current_owner_join.role = "admin" | |||
| # Update the role of the target member | |||
| target_member_join.role = new_role | |||
| @@ -480,8 +471,8 @@ class TenantService: | |||
| @staticmethod | |||
| def dissolve_tenant(tenant: Tenant, operator: Account) -> None: | |||
| """Dissolve tenant""" | |||
| if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): | |||
| raise NoPermissionError('No permission to dissolve tenant.') | |||
| if not TenantService.check_member_permission(tenant, operator, operator, "remove"): | |||
| raise NoPermissionError("No permission to dissolve tenant.") | |||
| db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() | |||
| db.session.delete(tenant) | |||
| db.session.commit() | |||
| @@ -494,10 +485,9 @@ class TenantService: | |||
| class RegisterService: | |||
| @classmethod | |||
| def _get_invitation_token_key(cls, token: str) -> str: | |||
| return f'member_invite:token:{token}' | |||
| return f"member_invite:token:{token}" | |||
| @classmethod | |||
| def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: | |||
| @@ -523,9 +513,7 @@ class RegisterService: | |||
| TenantService.create_owner_tenant_if_not_exist(account) | |||
| dify_setup = DifySetup( | |||
| version=dify_config.CURRENT_VERSION | |||
| ) | |||
| dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) | |||
| db.session.add(dify_setup) | |||
| db.session.commit() | |||
| except Exception as e: | |||
| @@ -535,34 +523,35 @@ class RegisterService: | |||
| db.session.query(Tenant).delete() | |||
| db.session.commit() | |||
| logging.exception(f'Setup failed: {e}') | |||
| raise ValueError(f'Setup failed: {e}') | |||
| logging.exception(f"Setup failed: {e}") | |||
| raise ValueError(f"Setup failed: {e}") | |||
| @classmethod | |||
| def register(cls, email, name, | |||
| password: Optional[str] = None, | |||
| open_id: Optional[str] = None, | |||
| provider: Optional[str] = None, | |||
| language: Optional[str] = None, | |||
| status: Optional[AccountStatus] = None) -> Account: | |||
| def register( | |||
| cls, | |||
| email, | |||
| name, | |||
| password: Optional[str] = None, | |||
| open_id: Optional[str] = None, | |||
| provider: Optional[str] = None, | |||
| language: Optional[str] = None, | |||
| status: Optional[AccountStatus] = None, | |||
| ) -> Account: | |||
| db.session.begin_nested() | |||
| """Register account""" | |||
| try: | |||
| account = AccountService.create_account( | |||
| email=email, | |||
| name=name, | |||
| interface_language=language if language else languages[0], | |||
| password=password | |||
| email=email, name=name, interface_language=language if language else languages[0], password=password | |||
| ) | |||
| account.status = AccountStatus.ACTIVE.value if not status else status.value | |||
| account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| if open_id is not None or provider is not None: | |||
| AccountService.link_account_integrate(provider, open_id, account) | |||
| if dify_config.EDITION != 'SELF_HOSTED': | |||
| if dify_config.EDITION != "SELF_HOSTED": | |||
| tenant = TenantService.create_tenant(f"{account.name}'s Workspace") | |||
| TenantService.create_tenant_member(tenant, account, role='owner') | |||
| TenantService.create_tenant_member(tenant, account, role="owner") | |||
| account.current_tenant = tenant | |||
| tenant_was_created.send(tenant) | |||
| @@ -570,30 +559,29 @@ class RegisterService: | |||
| db.session.commit() | |||
| except Exception as e: | |||
| db.session.rollback() | |||
| logging.error(f'Register failed: {e}') | |||
| raise AccountRegisterError(f'Registration failed: {e}') from e | |||
| logging.error(f"Register failed: {e}") | |||
| raise AccountRegisterError(f"Registration failed: {e}") from e | |||
| return account | |||
| @classmethod | |||
| def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: | |||
| def invite_new_member( | |||
| cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None | |||
| ) -> str: | |||
| """Invite new member""" | |||
| account = Account.query.filter_by(email=email).first() | |||
| if not account: | |||
| TenantService.check_member_permission(tenant, inviter, None, 'add') | |||
| name = email.split('@')[0] | |||
| TenantService.check_member_permission(tenant, inviter, None, "add") | |||
| name = email.split("@")[0] | |||
| account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) | |||
| # Create new tenant member for invited tenant | |||
| TenantService.create_tenant_member(tenant, account, role) | |||
| TenantService.switch_tenant(account, tenant.id) | |||
| else: | |||
| TenantService.check_member_permission(tenant, inviter, account, 'add') | |||
| ta = TenantAccountJoin.query.filter_by( | |||
| tenant_id=tenant.id, | |||
| account_id=account.id | |||
| ).first() | |||
| TenantService.check_member_permission(tenant, inviter, account, "add") | |||
| ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() | |||
| if not ta: | |||
| TenantService.create_tenant_member(tenant, account, role) | |||
| @@ -609,7 +597,7 @@ class RegisterService: | |||
| language=account.interface_language, | |||
| to=email, | |||
| token=token, | |||
| inviter_name=inviter.name if inviter else 'Dify', | |||
| inviter_name=inviter.name if inviter else "Dify", | |||
| workspace_name=tenant.name, | |||
| ) | |||
| @@ -619,23 +607,19 @@ class RegisterService: | |||
| def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: | |||
| token = str(uuid.uuid4()) | |||
| invitation_data = { | |||
| 'account_id': account.id, | |||
| 'email': account.email, | |||
| 'workspace_id': tenant.id, | |||
| "account_id": account.id, | |||
| "email": account.email, | |||
| "workspace_id": tenant.id, | |||
| } | |||
| expiryHours = dify_config.INVITE_EXPIRY_HOURS | |||
| redis_client.setex( | |||
| cls._get_invitation_token_key(token), | |||
| expiryHours * 60 * 60, | |||
| json.dumps(invitation_data) | |||
| ) | |||
| redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) | |||
| return token | |||
| @classmethod | |||
| def revoke_token(cls, workspace_id: str, email: str, token: str): | |||
| if workspace_id and email: | |||
| email_hash = sha256(email.encode()).hexdigest() | |||
| cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) | |||
| cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) | |||
| redis_client.delete(cache_key) | |||
| else: | |||
| redis_client.delete(cls._get_invitation_token_key(token)) | |||
| @@ -646,17 +630,21 @@ class RegisterService: | |||
| if not invitation_data: | |||
| return None | |||
| tenant = db.session.query(Tenant).filter( | |||
| Tenant.id == invitation_data['workspace_id'], | |||
| Tenant.status == 'normal' | |||
| ).first() | |||
| tenant = ( | |||
| db.session.query(Tenant) | |||
| .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") | |||
| .first() | |||
| ) | |||
| if not tenant: | |||
| return None | |||
| tenant_account = db.session.query(Account, TenantAccountJoin.role).join( | |||
| TenantAccountJoin, Account.id == TenantAccountJoin.account_id | |||
| ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() | |||
| tenant_account = ( | |||
| db.session.query(Account, TenantAccountJoin.role) | |||
| .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) | |||
| .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) | |||
| .first() | |||
| ) | |||
| if not tenant_account: | |||
| return None | |||
| @@ -665,29 +653,29 @@ class RegisterService: | |||
| if not account: | |||
| return None | |||
| if invitation_data['account_id'] != str(account.id): | |||
| if invitation_data["account_id"] != str(account.id): | |||
| return None | |||
| return { | |||
| 'account': account, | |||
| 'data': invitation_data, | |||
| 'tenant': tenant, | |||
| "account": account, | |||
| "data": invitation_data, | |||
| "tenant": tenant, | |||
| } | |||
| @classmethod | |||
| def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: | |||
| if workspace_id is not None and email is not None: | |||
| email_hash = sha256(email.encode()).hexdigest() | |||
| cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' | |||
| cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" | |||
| account_id = redis_client.get(cache_key) | |||
| if not account_id: | |||
| return None | |||
| return { | |||
| 'account_id': account_id.decode('utf-8'), | |||
| 'email': email, | |||
| 'workspace_id': workspace_id, | |||
| "account_id": account_id.decode("utf-8"), | |||
| "email": email, | |||
| "workspace_id": workspace_id, | |||
| } | |||
| else: | |||
| data = redis_client.get(cls._get_invitation_token_key(token)) | |||
| @@ -1,4 +1,3 @@ | |||
| import copy | |||
| from core.prompt.prompt_templates.advanced_prompt_templates import ( | |||
| @@ -17,59 +16,78 @@ from models.model import AppMode | |||
| class AdvancedPromptTemplateService: | |||
| @classmethod | |||
| def get_prompt(cls, args: dict) -> dict: | |||
| app_mode = args['app_mode'] | |||
| model_mode = args['model_mode'] | |||
| model_name = args['model_name'] | |||
| has_context = args['has_context'] | |||
| app_mode = args["app_mode"] | |||
| model_mode = args["model_mode"] | |||
| model_name = args["model_name"] | |||
| has_context = args["has_context"] | |||
| if 'baichuan' in model_name.lower(): | |||
| if "baichuan" in model_name.lower(): | |||
| return cls.get_baichuan_prompt(app_mode, model_mode, has_context) | |||
| else: | |||
| return cls.get_common_prompt(app_mode, model_mode, has_context) | |||
| @classmethod | |||
| def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: | |||
| def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: | |||
| context_prompt = copy.deepcopy(CONTEXT) | |||
| if app_mode == AppMode.CHAT.value: | |||
| if model_mode == "completion": | |||
| return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) | |||
| return cls.get_completion_prompt( | |||
| copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt | |||
| ) | |||
| elif model_mode == "chat": | |||
| return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) | |||
| elif app_mode == AppMode.COMPLETION.value: | |||
| if model_mode == "completion": | |||
| return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) | |||
| return cls.get_completion_prompt( | |||
| copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt | |||
| ) | |||
| elif model_mode == "chat": | |||
| return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) | |||
| return cls.get_chat_prompt( | |||
| copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt | |||
| ) | |||
| @classmethod | |||
| def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: | |||
| if has_context == 'true': | |||
| prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] | |||
| if has_context == "true": | |||
| prompt_template["completion_prompt_config"]["prompt"]["text"] = ( | |||
| context + prompt_template["completion_prompt_config"]["prompt"]["text"] | |||
| ) | |||
| return prompt_template | |||
| @classmethod | |||
| def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: | |||
| if has_context == 'true': | |||
| prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] | |||
| if has_context == "true": | |||
| prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( | |||
| context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] | |||
| ) | |||
| return prompt_template | |||
| @classmethod | |||
| def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: | |||
| def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: | |||
| baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) | |||
| if app_mode == AppMode.CHAT.value: | |||
| if model_mode == "completion": | |||
| return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) | |||
| return cls.get_completion_prompt( | |||
| copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt | |||
| ) | |||
| elif model_mode == "chat": | |||
| return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) | |||
| return cls.get_chat_prompt( | |||
| copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt | |||
| ) | |||
| elif app_mode == AppMode.COMPLETION.value: | |||
| if model_mode == "completion": | |||
| return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) | |||
| return cls.get_completion_prompt( | |||
| copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), | |||
| has_context, | |||
| baichuan_context_prompt, | |||
| ) | |||
| elif model_mode == "chat": | |||
| return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) | |||
| return cls.get_chat_prompt( | |||
| copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt | |||
| ) | |||
| @@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough | |||
| class AgentService: | |||
| @classmethod | |||
| def get_agent_logs(cls, app_model: App, | |||
| conversation_id: str, | |||
| message_id: str) -> dict: | |||
| def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: | |||
| """ | |||
| Service to get agent logs | |||
| """ | |||
| conversation: Conversation = db.session.query(Conversation).filter( | |||
| Conversation.id == conversation_id, | |||
| Conversation.app_id == app_model.id, | |||
| ).first() | |||
| conversation: Conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter( | |||
| Conversation.id == conversation_id, | |||
| Conversation.app_id == app_model.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise ValueError(f"Conversation not found: {conversation_id}") | |||
| message: Message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.conversation_id == conversation_id, | |||
| ).first() | |||
| message: Message = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.id == message_id, | |||
| Message.conversation_id == conversation_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not message: | |||
| raise ValueError(f"Message not found: {message_id}") | |||
| agent_thoughts: list[MessageAgentThought] = message.agent_thoughts | |||
| if conversation.from_end_user_id: | |||
| # only select name field | |||
| executor = db.session.query(EndUser, EndUser.name).filter( | |||
| EndUser.id == conversation.from_end_user_id | |||
| ).first() | |||
| executor = ( | |||
| db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() | |||
| ) | |||
| else: | |||
| executor = db.session.query(Account, Account.name).filter( | |||
| Account.id == conversation.from_account_id | |||
| ).first() | |||
| executor = ( | |||
| db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() | |||
| ) | |||
| if executor: | |||
| executor = executor.name | |||
| else: | |||
| executor = 'Unknown' | |||
| executor = "Unknown" | |||
| timezone = pytz.timezone(current_user.timezone) | |||
| result = { | |||
| 'meta': { | |||
| 'status': 'success', | |||
| 'executor': executor, | |||
| 'start_time': message.created_at.astimezone(timezone).isoformat(), | |||
| 'elapsed_time': message.provider_response_latency, | |||
| 'total_tokens': message.answer_tokens + message.message_tokens, | |||
| 'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'), | |||
| 'iterations': len(agent_thoughts), | |||
| "meta": { | |||
| "status": "success", | |||
| "executor": executor, | |||
| "start_time": message.created_at.astimezone(timezone).isoformat(), | |||
| "elapsed_time": message.provider_response_latency, | |||
| "total_tokens": message.answer_tokens + message.message_tokens, | |||
| "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"), | |||
| "iterations": len(agent_thoughts), | |||
| }, | |||
| 'iterations': [], | |||
| 'files': message.files, | |||
| "iterations": [], | |||
| "files": message.files, | |||
| } | |||
| agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) | |||
| @@ -86,12 +92,12 @@ class AgentService: | |||
| tool_input = tool_inputs.get(tool_name, {}) | |||
| tool_output = tool_outputs.get(tool_name, {}) | |||
| tool_meta_data = tool_meta.get(tool_name, {}) | |||
| tool_config = tool_meta_data.get('tool_config', {}) | |||
| if tool_config.get('tool_provider_type', '') != 'dataset-retrieval': | |||
| tool_config = tool_meta_data.get("tool_config", {}) | |||
| if tool_config.get("tool_provider_type", "") != "dataset-retrieval": | |||
| tool_icon = ToolManager.get_tool_icon( | |||
| tenant_id=app_model.tenant_id, | |||
| provider_type=tool_config.get('tool_provider_type', ''), | |||
| provider_id=tool_config.get('tool_provider', ''), | |||
| provider_type=tool_config.get("tool_provider_type", ""), | |||
| provider_id=tool_config.get("tool_provider", ""), | |||
| ) | |||
| if not tool_icon: | |||
| tool_entity = find_agent_tool(tool_name) | |||
| @@ -102,30 +108,34 @@ class AgentService: | |||
| provider_id=tool_entity.provider_id, | |||
| ) | |||
| else: | |||
| tool_icon = '' | |||
| tool_calls.append({ | |||
| 'status': 'success' if not tool_meta_data.get('error') else 'error', | |||
| 'error': tool_meta_data.get('error'), | |||
| 'time_cost': tool_meta_data.get('time_cost', 0), | |||
| 'tool_name': tool_name, | |||
| 'tool_label': tool_label, | |||
| 'tool_input': tool_input, | |||
| 'tool_output': tool_output, | |||
| 'tool_parameters': tool_meta_data.get('tool_parameters', {}), | |||
| 'tool_icon': tool_icon, | |||
| }) | |||
| result['iterations'].append({ | |||
| 'tokens': agent_thought.tokens, | |||
| 'tool_calls': tool_calls, | |||
| 'tool_raw': { | |||
| 'inputs': agent_thought.tool_input, | |||
| 'outputs': agent_thought.observation, | |||
| }, | |||
| 'thought': agent_thought.thought, | |||
| 'created_at': agent_thought.created_at.isoformat(), | |||
| 'files': agent_thought.files, | |||
| }) | |||
| return result | |||
| tool_icon = "" | |||
| tool_calls.append( | |||
| { | |||
| "status": "success" if not tool_meta_data.get("error") else "error", | |||
| "error": tool_meta_data.get("error"), | |||
| "time_cost": tool_meta_data.get("time_cost", 0), | |||
| "tool_name": tool_name, | |||
| "tool_label": tool_label, | |||
| "tool_input": tool_input, | |||
| "tool_output": tool_output, | |||
| "tool_parameters": tool_meta_data.get("tool_parameters", {}), | |||
| "tool_icon": tool_icon, | |||
| } | |||
| ) | |||
| result["iterations"].append( | |||
| { | |||
| "tokens": agent_thought.tokens, | |||
| "tool_calls": tool_calls, | |||
| "tool_raw": { | |||
| "inputs": agent_thought.tool_input, | |||
| "outputs": agent_thought.observation, | |||
| }, | |||
| "thought": agent_thought.thought, | |||
| "created_at": agent_thought.created_at.isoformat(), | |||
| "files": agent_thought.files, | |||
| } | |||
| ) | |||
| return result | |||
| @@ -23,21 +23,18 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| if args.get('message_id'): | |||
| message_id = str(args['message_id']) | |||
| if args.get("message_id"): | |||
| message_id = str(args["message_id"]) | |||
| # get message info | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app.id | |||
| ).first() | |||
| message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() | |||
| if not message: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -45,159 +42,166 @@ class AppAnnotationService: | |||
| annotation = message.annotation | |||
| # save the message annotation | |||
| if annotation: | |||
| annotation.content = args['answer'] | |||
| annotation.question = args['question'] | |||
| annotation.content = args["answer"] | |||
| annotation.question = args["question"] | |||
| else: | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| conversation_id=message.conversation_id, | |||
| message_id=message.id, | |||
| content=args['answer'], | |||
| question=args['question'], | |||
| account_id=current_user.id | |||
| content=args["answer"], | |||
| question=args["question"], | |||
| account_id=current_user.id, | |||
| ) | |||
| else: | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| content=args['answer'], | |||
| question=args['question'], | |||
| account_id=current_user.id | |||
| app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id | |||
| ) | |||
| db.session.add(annotation) | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() | |||
| ) | |||
| if annotation_setting: | |||
| add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, | |||
| app_id, annotation_setting.collection_binding_id) | |||
| add_annotation_to_index_task.delay( | |||
| annotation.id, | |||
| args["question"], | |||
| current_user.current_tenant_id, | |||
| app_id, | |||
| annotation_setting.collection_binding_id, | |||
| ) | |||
| return annotation | |||
| @classmethod | |||
| def enable_app_annotation(cls, args: dict, app_id: str) -> dict: | |||
| enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) | |||
| enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) | |||
| cache_result = redis_client.get(enable_app_annotation_key) | |||
| if cache_result is not None: | |||
| return { | |||
| 'job_id': cache_result, | |||
| 'job_status': 'processing' | |||
| } | |||
| return {"job_id": cache_result, "job_status": "processing"} | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) | |||
| enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(enable_app_annotation_job_key, 'waiting') | |||
| enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, | |||
| args['score_threshold'], | |||
| args['embedding_provider_name'], args['embedding_model_name']) | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| } | |||
| redis_client.setnx(enable_app_annotation_job_key, "waiting") | |||
| enable_annotation_reply_task.delay( | |||
| str(job_id), | |||
| app_id, | |||
| current_user.id, | |||
| current_user.current_tenant_id, | |||
| args["score_threshold"], | |||
| args["embedding_provider_name"], | |||
| args["embedding_model_name"], | |||
| ) | |||
| return {"job_id": job_id, "job_status": "waiting"} | |||
| @classmethod | |||
| def disable_app_annotation(cls, app_id: str) -> dict: | |||
| disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) | |||
| disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) | |||
| cache_result = redis_client.get(disable_app_annotation_key) | |||
| if cache_result is not None: | |||
| return { | |||
| 'job_id': cache_result, | |||
| 'job_status': 'processing' | |||
| } | |||
| return {"job_id": cache_result, "job_status": "processing"} | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) | |||
| disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(disable_app_annotation_job_key, 'waiting') | |||
| redis_client.setnx(disable_app_annotation_job_key, "waiting") | |||
| disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| } | |||
| return {"job_id": job_id, "job_status": "waiting"} | |||
| @classmethod | |||
| def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| if keyword: | |||
| annotations = (db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .filter( | |||
| or_( | |||
| MessageAnnotation.question.ilike('%{}%'.format(keyword)), | |||
| MessageAnnotation.content.ilike('%{}%'.format(keyword)) | |||
| annotations = ( | |||
| db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .filter( | |||
| or_( | |||
| MessageAnnotation.question.ilike("%{}%".format(keyword)), | |||
| MessageAnnotation.content.ilike("%{}%".format(keyword)), | |||
| ) | |||
| ) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| ) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) | |||
| else: | |||
| annotations = (db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) | |||
| annotations = ( | |||
| db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| ) | |||
| return annotations.items, annotations.total | |||
| @classmethod | |||
| def export_annotation_list_by_app_id(cls, app_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotations = (db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc()).all()) | |||
| annotations = ( | |||
| db.session.query(MessageAnnotation) | |||
| .filter(MessageAnnotation.app_id == app_id) | |||
| .order_by(MessageAnnotation.created_at.desc()) | |||
| .all() | |||
| ) | |||
| return annotations | |||
| @classmethod | |||
| def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation = MessageAnnotation( | |||
| app_id=app.id, | |||
| content=args['answer'], | |||
| question=args['question'], | |||
| account_id=current_user.id | |||
| app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id | |||
| ) | |||
| db.session.add(annotation) | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() | |||
| ) | |||
| if annotation_setting: | |||
| add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, | |||
| app_id, annotation_setting.collection_binding_id) | |||
| add_annotation_to_index_task.delay( | |||
| annotation.id, | |||
| args["question"], | |||
| current_user.current_tenant_id, | |||
| app_id, | |||
| annotation_setting.collection_binding_id, | |||
| ) | |||
| return annotation | |||
| @classmethod | |||
| def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| @@ -207,30 +211,34 @@ class AppAnnotationService: | |||
| if not annotation: | |||
| raise NotFound("Annotation not found") | |||
| annotation.content = args['answer'] | |||
| annotation.question = args['question'] | |||
| annotation.content = args["answer"] | |||
| annotation.question = args["question"] | |||
| db.session.commit() | |||
| # if annotation reply is enabled , add annotation to index | |||
| app_annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id | |||
| ).first() | |||
| app_annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() | |||
| ) | |||
| if app_annotation_setting: | |||
| update_annotation_to_index_task.delay(annotation.id, annotation.question, | |||
| current_user.current_tenant_id, | |||
| app_id, app_annotation_setting.collection_binding_id) | |||
| update_annotation_to_index_task.delay( | |||
| annotation.id, | |||
| annotation.question, | |||
| current_user.current_tenant_id, | |||
| app_id, | |||
| app_annotation_setting.collection_binding_id, | |||
| ) | |||
| return annotation | |||
| @classmethod | |||
| def delete_app_annotation(cls, app_id: str, annotation_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| @@ -242,33 +250,34 @@ class AppAnnotationService: | |||
| db.session.delete(annotation) | |||
| annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) | |||
| .filter(AppAnnotationHitHistory.annotation_id == annotation_id) | |||
| .all() | |||
| ) | |||
| annotation_hit_histories = ( | |||
| db.session.query(AppAnnotationHitHistory) | |||
| .filter(AppAnnotationHitHistory.annotation_id == annotation_id) | |||
| .all() | |||
| ) | |||
| if annotation_hit_histories: | |||
| for annotation_hit_history in annotation_hit_histories: | |||
| db.session.delete(annotation_hit_history) | |||
| db.session.commit() | |||
| # if annotation reply is enabled , delete annotation index | |||
| app_annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id | |||
| ).first() | |||
| app_annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() | |||
| ) | |||
| if app_annotation_setting: | |||
| delete_annotation_index_task.delay(annotation.id, app_id, | |||
| current_user.current_tenant_id, | |||
| app_annotation_setting.collection_binding_id) | |||
| delete_annotation_index_task.delay( | |||
| annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id | |||
| ) | |||
| @classmethod | |||
| def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| @@ -278,10 +287,7 @@ class AppAnnotationService: | |||
| df = pd.read_csv(file) | |||
| result = [] | |||
| for index, row in df.iterrows(): | |||
| content = { | |||
| 'question': row[0], | |||
| 'answer': row[1] | |||
| } | |||
| content = {"question": row[0], "answer": row[1]} | |||
| result.append(content) | |||
| if len(result) == 0: | |||
| raise ValueError("The CSV file is empty.") | |||
| @@ -293,28 +299,24 @@ class AppAnnotationService: | |||
| raise ValueError("The number of annotations exceeds the limit of your subscription.") | |||
| # async job | |||
| job_id = str(uuid.uuid4()) | |||
| indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) | |||
| indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) | |||
| # send batch add segments task | |||
| redis_client.setnx(indexing_cache_key, 'waiting') | |||
| batch_import_annotations_task.delay(str(job_id), result, app_id, | |||
| current_user.current_tenant_id, current_user.id) | |||
| redis_client.setnx(indexing_cache_key, "waiting") | |||
| batch_import_annotations_task.delay( | |||
| str(job_id), result, app_id, current_user.current_tenant_id, current_user.id | |||
| ) | |||
| except Exception as e: | |||
| return { | |||
| 'error_msg': str(e) | |||
| } | |||
| return { | |||
| 'job_id': job_id, | |||
| 'job_status': 'waiting' | |||
| } | |||
| return {"error_msg": str(e)} | |||
| return {"job_id": job_id, "job_status": "waiting"} | |||
| @classmethod | |||
| def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| @@ -324,12 +326,15 @@ class AppAnnotationService: | |||
| if not annotation: | |||
| raise NotFound("Annotation not found") | |||
| annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) | |||
| .filter(AppAnnotationHitHistory.app_id == app_id, | |||
| AppAnnotationHitHistory.annotation_id == annotation_id, | |||
| ) | |||
| .order_by(AppAnnotationHitHistory.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) | |||
| annotation_hit_histories = ( | |||
| db.session.query(AppAnnotationHitHistory) | |||
| .filter( | |||
| AppAnnotationHitHistory.app_id == app_id, | |||
| AppAnnotationHitHistory.annotation_id == annotation_id, | |||
| ) | |||
| .order_by(AppAnnotationHitHistory.created_at.desc()) | |||
| .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| ) | |||
| return annotation_hit_histories.items, annotation_hit_histories.total | |||
| @classmethod | |||
| @@ -341,15 +346,21 @@ class AppAnnotationService: | |||
| return annotation | |||
| @classmethod | |||
| def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, | |||
| annotation_content: str, query: str, user_id: str, | |||
| message_id: str, from_source: str, score: float): | |||
| def add_annotation_history( | |||
| cls, | |||
| annotation_id: str, | |||
| app_id: str, | |||
| annotation_question: str, | |||
| annotation_content: str, | |||
| query: str, | |||
| user_id: str, | |||
| message_id: str, | |||
| from_source: str, | |||
| score: float, | |||
| ): | |||
| # add hit count to annotation | |||
| db.session.query(MessageAnnotation).filter( | |||
| MessageAnnotation.id == annotation_id | |||
| ).update( | |||
| {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, | |||
| synchronize_session=False | |||
| db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( | |||
| {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False | |||
| ) | |||
| annotation_hit_history = AppAnnotationHitHistory( | |||
| @@ -361,7 +372,7 @@ class AppAnnotationService: | |||
| score=score, | |||
| message_id=message_id, | |||
| annotation_question=annotation_question, | |||
| annotation_content=annotation_content | |||
| annotation_content=annotation_content, | |||
| ) | |||
| db.session.add(annotation_hit_history) | |||
| db.session.commit() | |||
| @@ -369,17 +380,18 @@ class AppAnnotationService: | |||
| @classmethod | |||
| def get_app_annotation_setting_by_app_id(cls, app_id: str): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id).first() | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() | |||
| ) | |||
| if annotation_setting: | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| return { | |||
| @@ -388,32 +400,34 @@ class AppAnnotationService: | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name | |||
| } | |||
| "embedding_model_name": collection_binding_detail.model_name, | |||
| }, | |||
| } | |||
| return { | |||
| "enabled": False | |||
| } | |||
| return {"enabled": False} | |||
| @classmethod | |||
| def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): | |||
| # get app info | |||
| app = db.session.query(App).filter( | |||
| App.id == app_id, | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.status == 'normal' | |||
| ).first() | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_id, | |||
| AppAnnotationSetting.id == annotation_setting_id, | |||
| ).first() | |||
| annotation_setting = ( | |||
| db.session.query(AppAnnotationSetting) | |||
| .filter( | |||
| AppAnnotationSetting.app_id == app_id, | |||
| AppAnnotationSetting.id == annotation_setting_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not annotation_setting: | |||
| raise NotFound("App annotation not found") | |||
| annotation_setting.score_threshold = args['score_threshold'] | |||
| annotation_setting.score_threshold = args["score_threshold"] | |||
| annotation_setting.updated_user_id = current_user.id | |||
| annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(annotation_setting) | |||
| @@ -427,6 +441,6 @@ class AppAnnotationService: | |||
| "score_threshold": annotation_setting.score_threshold, | |||
| "embedding_model": { | |||
| "embedding_provider_name": collection_binding_detail.provider_name, | |||
| "embedding_model_name": collection_binding_detail.model_name | |||
| } | |||
| "embedding_model_name": collection_binding_detail.model_name, | |||
| }, | |||
| } | |||
| @@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint | |||
| class APIBasedExtensionService: | |||
| @staticmethod | |||
| def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: | |||
| extension_list = db.session.query(APIBasedExtension) \ | |||
| .filter_by(tenant_id=tenant_id) \ | |||
| .order_by(APIBasedExtension.created_at.desc()) \ | |||
| .all() | |||
| extension_list = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter_by(tenant_id=tenant_id) | |||
| .order_by(APIBasedExtension.created_at.desc()) | |||
| .all() | |||
| ) | |||
| for extension in extension_list: | |||
| extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) | |||
| @@ -35,10 +36,12 @@ class APIBasedExtensionService: | |||
| @staticmethod | |||
| def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: | |||
| extension = db.session.query(APIBasedExtension) \ | |||
| .filter_by(tenant_id=tenant_id) \ | |||
| .filter_by(id=api_based_extension_id) \ | |||
| extension = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter_by(tenant_id=tenant_id) | |||
| .filter_by(id=api_based_extension_id) | |||
| .first() | |||
| ) | |||
| if not extension: | |||
| raise ValueError("API based extension is not found") | |||
| @@ -55,20 +58,24 @@ class APIBasedExtensionService: | |||
| if not extension_data.id: | |||
| # case one: check new data, name must be unique | |||
| is_name_existed = db.session.query(APIBasedExtension) \ | |||
| .filter_by(tenant_id=extension_data.tenant_id) \ | |||
| .filter_by(name=extension_data.name) \ | |||
| is_name_existed = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter_by(tenant_id=extension_data.tenant_id) | |||
| .filter_by(name=extension_data.name) | |||
| .first() | |||
| ) | |||
| if is_name_existed: | |||
| raise ValueError("name must be unique, it is already existed") | |||
| else: | |||
| # case two: check existing data, name must be unique | |||
| is_name_existed = db.session.query(APIBasedExtension) \ | |||
| .filter_by(tenant_id=extension_data.tenant_id) \ | |||
| .filter_by(name=extension_data.name) \ | |||
| .filter(APIBasedExtension.id != extension_data.id) \ | |||
| is_name_existed = ( | |||
| db.session.query(APIBasedExtension) | |||
| .filter_by(tenant_id=extension_data.tenant_id) | |||
| .filter_by(name=extension_data.name) | |||
| .filter(APIBasedExtension.id != extension_data.id) | |||
| .first() | |||
| ) | |||
| if is_name_existed: | |||
| raise ValueError("name must be unique, it is already existed") | |||
| @@ -92,7 +99,7 @@ class APIBasedExtensionService: | |||
| try: | |||
| client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) | |||
| resp = client.request(point=APIBasedExtensionPoint.PING, params={}) | |||
| if resp.get('result') != 'pong': | |||
| if resp.get("result") != "pong": | |||
| raise ValueError(resp) | |||
| except Exception as e: | |||
| raise ValueError("connection error: {}".format(e)) | |||
| @@ -75,43 +75,44 @@ class AppDslService: | |||
| # check or repair dsl version | |||
| import_data = cls._check_or_fix_dsl(import_data) | |||
| app_data = import_data.get('app') | |||
| app_data = import_data.get("app") | |||
| if not app_data: | |||
| raise ValueError("Missing app in data argument") | |||
| # get app basic info | |||
| name = args.get("name") if args.get("name") else app_data.get('name') | |||
| description = args.get("description") if args.get("description") else app_data.get('description', '') | |||
| icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type') | |||
| icon = args.get("icon") if args.get("icon") else app_data.get('icon') | |||
| icon_background = args.get("icon_background") if args.get("icon_background") \ | |||
| else app_data.get('icon_background') | |||
| name = args.get("name") if args.get("name") else app_data.get("name") | |||
| description = args.get("description") if args.get("description") else app_data.get("description", "") | |||
| icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") | |||
| icon = args.get("icon") if args.get("icon") else app_data.get("icon") | |||
| icon_background = ( | |||
| args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") | |||
| ) | |||
| # import dsl and create app | |||
| app_mode = AppMode.value_of(app_data.get('mode')) | |||
| app_mode = AppMode.value_of(app_data.get("mode")) | |||
| if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: | |||
| app = cls._import_and_create_new_workflow_based_app( | |||
| tenant_id=tenant_id, | |||
| app_mode=app_mode, | |||
| workflow_data=import_data.get('workflow'), | |||
| workflow_data=import_data.get("workflow"), | |||
| account=account, | |||
| name=name, | |||
| description=description, | |||
| icon_type=icon_type, | |||
| icon=icon, | |||
| icon_background=icon_background | |||
| icon_background=icon_background, | |||
| ) | |||
| elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: | |||
| app = cls._import_and_create_new_model_config_based_app( | |||
| tenant_id=tenant_id, | |||
| app_mode=app_mode, | |||
| model_config_data=import_data.get('model_config'), | |||
| model_config_data=import_data.get("model_config"), | |||
| account=account, | |||
| name=name, | |||
| description=description, | |||
| icon_type=icon_type, | |||
| icon=icon, | |||
| icon_background=icon_background | |||
| icon_background=icon_background, | |||
| ) | |||
| else: | |||
| raise ValueError("Invalid app mode") | |||
| @@ -134,27 +135,26 @@ class AppDslService: | |||
| # check or repair dsl version | |||
| import_data = cls._check_or_fix_dsl(import_data) | |||
| app_data = import_data.get('app') | |||
| app_data = import_data.get("app") | |||
| if not app_data: | |||
| raise ValueError("Missing app in data argument") | |||
| # import dsl and overwrite app | |||
| app_mode = AppMode.value_of(app_data.get('mode')) | |||
| app_mode = AppMode.value_of(app_data.get("mode")) | |||
| if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: | |||
| raise ValueError("Only support import workflow in advanced-chat or workflow app.") | |||
| if app_data.get('mode') != app_model.mode: | |||
| raise ValueError( | |||
| f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") | |||
| if app_data.get("mode") != app_model.mode: | |||
| raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") | |||
| return cls._import_and_overwrite_workflow_based_app( | |||
| app_model=app_model, | |||
| workflow_data=import_data.get('workflow'), | |||
| workflow_data=import_data.get("workflow"), | |||
| account=account, | |||
| ) | |||
| @classmethod | |||
| def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: | |||
| def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: | |||
| """ | |||
| Export app | |||
| :param app_model: App instance | |||
| @@ -168,14 +168,16 @@ class AppDslService: | |||
| "app": { | |||
| "name": app_model.name, | |||
| "mode": app_model.mode, | |||
| "icon": '🤖' if app_model.icon_type == 'image' else app_model.icon, | |||
| "icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background, | |||
| "description": app_model.description | |||
| } | |||
| "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, | |||
| "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, | |||
| "description": app_model.description, | |||
| }, | |||
| } | |||
| if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: | |||
| cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) | |||
| cls._append_workflow_export_data( | |||
| export_data=export_data, app_model=app_model, include_secret=include_secret | |||
| ) | |||
| else: | |||
| cls._append_model_config_export_data(export_data, app_model) | |||
| @@ -188,31 +190,35 @@ class AppDslService: | |||
| :param import_data: import data | |||
| """ | |||
| if not import_data.get('version'): | |||
| import_data['version'] = "0.1.0" | |||
| if not import_data.get("version"): | |||
| import_data["version"] = "0.1.0" | |||
| if not import_data.get('kind') or import_data.get('kind') != "app": | |||
| import_data['kind'] = "app" | |||
| if not import_data.get("kind") or import_data.get("kind") != "app": | |||
| import_data["kind"] = "app" | |||
| if import_data.get('version') != current_dsl_version: | |||
| if import_data.get("version") != current_dsl_version: | |||
| # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. | |||
| logger.warning(f"DSL version {import_data.get('version')} is not compatible " | |||
| f"with current version {current_dsl_version}, related to " | |||
| f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.") | |||
| logger.warning( | |||
| f"DSL version {import_data.get('version')} is not compatible " | |||
| f"with current version {current_dsl_version}, related to " | |||
| f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." | |||
| ) | |||
| return import_data | |||
| @classmethod | |||
| def _import_and_create_new_workflow_based_app(cls, | |||
| tenant_id: str, | |||
| app_mode: AppMode, | |||
| workflow_data: dict, | |||
| account: Account, | |||
| name: str, | |||
| description: str, | |||
| icon_type: str, | |||
| icon: str, | |||
| icon_background: str) -> App: | |||
| def _import_and_create_new_workflow_based_app( | |||
| cls, | |||
| tenant_id: str, | |||
| app_mode: AppMode, | |||
| workflow_data: dict, | |||
| account: Account, | |||
| name: str, | |||
| description: str, | |||
| icon_type: str, | |||
| icon: str, | |||
| icon_background: str, | |||
| ) -> App: | |||
| """ | |||
| Import app dsl and create new workflow based app | |||
| @@ -227,8 +233,7 @@ class AppDslService: | |||
| :param icon_background: app icon background | |||
| """ | |||
| if not workflow_data: | |||
| raise ValueError("Missing workflow in data argument " | |||
| "when app mode is advanced-chat or workflow") | |||
| raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") | |||
| app = cls._create_app( | |||
| tenant_id=tenant_id, | |||
| @@ -238,37 +243,32 @@ class AppDslService: | |||
| description=description, | |||
| icon_type=icon_type, | |||
| icon=icon, | |||
| icon_background=icon_background | |||
| icon_background=icon_background, | |||
| ) | |||
| # init draft workflow | |||
| environment_variables_list = workflow_data.get('environment_variables') or [] | |||
| environment_variables_list = workflow_data.get("environment_variables") or [] | |||
| environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] | |||
| conversation_variables_list = workflow_data.get('conversation_variables') or [] | |||
| conversation_variables_list = workflow_data.get("conversation_variables") or [] | |||
| conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] | |||
| workflow_service = WorkflowService() | |||
| draft_workflow = workflow_service.sync_draft_workflow( | |||
| app_model=app, | |||
| graph=workflow_data.get('graph', {}), | |||
| features=workflow_data.get('../core/app/features', {}), | |||
| graph=workflow_data.get("graph", {}), | |||
| features=workflow_data.get("../core/app/features", {}), | |||
| unique_hash=None, | |||
| account=account, | |||
| environment_variables=environment_variables, | |||
| conversation_variables=conversation_variables, | |||
| ) | |||
| workflow_service.publish_workflow( | |||
| app_model=app, | |||
| account=account, | |||
| draft_workflow=draft_workflow | |||
| ) | |||
| workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) | |||
| return app | |||
| @classmethod | |||
| def _import_and_overwrite_workflow_based_app(cls, | |||
| app_model: App, | |||
| workflow_data: dict, | |||
| account: Account) -> Workflow: | |||
| def _import_and_overwrite_workflow_based_app( | |||
| cls, app_model: App, workflow_data: dict, account: Account | |||
| ) -> Workflow: | |||
| """ | |||
| Import app dsl and overwrite workflow based app | |||
| @@ -277,8 +277,7 @@ class AppDslService: | |||
| :param account: Account instance | |||
| """ | |||
| if not workflow_data: | |||
| raise ValueError("Missing workflow in data argument " | |||
| "when app mode is advanced-chat or workflow") | |||
| raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") | |||
| # fetch draft workflow by app_model | |||
| workflow_service = WorkflowService() | |||
| @@ -289,14 +288,14 @@ class AppDslService: | |||
| unique_hash = None | |||
| # sync draft workflow | |||
| environment_variables_list = workflow_data.get('environment_variables') or [] | |||
| environment_variables_list = workflow_data.get("environment_variables") or [] | |||
| environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] | |||
| conversation_variables_list = workflow_data.get('conversation_variables') or [] | |||
| conversation_variables_list = workflow_data.get("conversation_variables") or [] | |||
| conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] | |||
| draft_workflow = workflow_service.sync_draft_workflow( | |||
| app_model=app_model, | |||
| graph=workflow_data.get('graph', {}), | |||
| features=workflow_data.get('features', {}), | |||
| graph=workflow_data.get("graph", {}), | |||
| features=workflow_data.get("features", {}), | |||
| unique_hash=unique_hash, | |||
| account=account, | |||
| environment_variables=environment_variables, | |||
| @@ -306,16 +305,18 @@ class AppDslService: | |||
| return draft_workflow | |||
| @classmethod | |||
| def _import_and_create_new_model_config_based_app(cls, | |||
| tenant_id: str, | |||
| app_mode: AppMode, | |||
| model_config_data: dict, | |||
| account: Account, | |||
| name: str, | |||
| description: str, | |||
| icon_type: str, | |||
| icon: str, | |||
| icon_background: str) -> App: | |||
| def _import_and_create_new_model_config_based_app( | |||
| cls, | |||
| tenant_id: str, | |||
| app_mode: AppMode, | |||
| model_config_data: dict, | |||
| account: Account, | |||
| name: str, | |||
| description: str, | |||
| icon_type: str, | |||
| icon: str, | |||
| icon_background: str, | |||
| ) -> App: | |||
| """ | |||
| Import app dsl and create new model config based app | |||
| @@ -329,8 +330,7 @@ class AppDslService: | |||
| :param icon_background: app icon background | |||
| """ | |||
| if not model_config_data: | |||
| raise ValueError("Missing model_config in data argument " | |||
| "when app mode is chat, agent-chat or completion") | |||
| raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion") | |||
| app = cls._create_app( | |||
| tenant_id=tenant_id, | |||
| @@ -340,7 +340,7 @@ class AppDslService: | |||
| description=description, | |||
| icon_type=icon_type, | |||
| icon=icon, | |||
| icon_background=icon_background | |||
| icon_background=icon_background, | |||
| ) | |||
| app_model_config = AppModelConfig() | |||
| @@ -352,23 +352,22 @@ class AppDslService: | |||
| app.app_model_config_id = app_model_config.id | |||
| app_model_config_was_updated.send( | |||
| app, | |||
| app_model_config=app_model_config | |||
| ) | |||
| app_model_config_was_updated.send(app, app_model_config=app_model_config) | |||
| return app | |||
| @classmethod | |||
| def _create_app(cls, | |||
| tenant_id: str, | |||
| app_mode: AppMode, | |||
| account: Account, | |||
| name: str, | |||
| description: str, | |||
| icon_type: str, | |||
| icon: str, | |||
| icon_background: str) -> App: | |||
| def _create_app( | |||
| cls, | |||
| tenant_id: str, | |||
| app_mode: AppMode, | |||
| account: Account, | |||
| name: str, | |||
| description: str, | |||
| icon_type: str, | |||
| icon: str, | |||
| icon_background: str, | |||
| ) -> App: | |||
| """ | |||
| Create new app | |||
| @@ -390,7 +389,7 @@ class AppDslService: | |||
| icon=icon, | |||
| icon_background=icon_background, | |||
| enable_site=True, | |||
| enable_api=True | |||
| enable_api=True, | |||
| ) | |||
| db.session.add(app) | |||
| @@ -412,7 +411,7 @@ class AppDslService: | |||
| if not workflow: | |||
| raise ValueError("Missing draft workflow configuration, please check.") | |||
| export_data['workflow'] = workflow.to_dict(include_secret=include_secret) | |||
| export_data["workflow"] = workflow.to_dict(include_secret=include_secret) | |||
| @classmethod | |||
| def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: | |||
| @@ -425,4 +424,4 @@ class AppDslService: | |||
| if not app_model_config: | |||
| raise ValueError("Missing app configuration, please check.") | |||
| export_data['model_config'] = app_model_config.to_dict() | |||
| export_data["model_config"] = app_model_config.to_dict() | |||
| @@ -14,14 +14,15 @@ from services.workflow_service import WorkflowService | |||
| class AppGenerateService: | |||
| @classmethod | |||
| def generate(cls, app_model: App, | |||
| user: Union[Account, EndUser], | |||
| args: Any, | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| def generate( | |||
| cls, | |||
| app_model: App, | |||
| user: Union[Account, EndUser], | |||
| args: Any, | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ): | |||
| """ | |||
| App Content Generate | |||
| :param app_model: app model | |||
| @@ -37,51 +38,54 @@ class AppGenerateService: | |||
| try: | |||
| request_id = rate_limit.enter(request_id) | |||
| if app_model.mode == AppMode.COMPLETION.value: | |||
| return rate_limit.generate(CompletionAppGenerator().generate( | |||
| app_model=app_model, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming | |||
| ), request_id) | |||
| return rate_limit.generate( | |||
| CompletionAppGenerator().generate( | |||
| app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming | |||
| ), | |||
| request_id, | |||
| ) | |||
| elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: | |||
| return rate_limit.generate(AgentChatAppGenerator().generate( | |||
| app_model=app_model, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming | |||
| ), request_id) | |||
| return rate_limit.generate( | |||
| AgentChatAppGenerator().generate( | |||
| app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming | |||
| ), | |||
| request_id, | |||
| ) | |||
| elif app_model.mode == AppMode.CHAT.value: | |||
| return rate_limit.generate(ChatAppGenerator().generate( | |||
| app_model=app_model, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming | |||
| ), request_id) | |||
| return rate_limit.generate( | |||
| ChatAppGenerator().generate( | |||
| app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming | |||
| ), | |||
| request_id, | |||
| ) | |||
| elif app_model.mode == AppMode.ADVANCED_CHAT.value: | |||
| workflow = cls._get_workflow(app_model, invoke_from) | |||
| return rate_limit.generate(AdvancedChatAppGenerator().generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming | |||
| ), request_id) | |||
| return rate_limit.generate( | |||
| AdvancedChatAppGenerator().generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming, | |||
| ), | |||
| request_id, | |||
| ) | |||
| elif app_model.mode == AppMode.WORKFLOW.value: | |||
| workflow = cls._get_workflow(app_model, invoke_from) | |||
| return rate_limit.generate(WorkflowAppGenerator().generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming | |||
| ), request_id) | |||
| return rate_limit.generate( | |||
| WorkflowAppGenerator().generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| user=user, | |||
| args=args, | |||
| invoke_from=invoke_from, | |||
| stream=streaming, | |||
| ), | |||
| request_id, | |||
| ) | |||
| else: | |||
| raise ValueError(f'Invalid app mode {app_model.mode}') | |||
| raise ValueError(f"Invalid app mode {app_model.mode}") | |||
| finally: | |||
| if not streaming: | |||
| rate_limit.exit(request_id) | |||
| @@ -94,38 +98,31 @@ class AppGenerateService: | |||
| return max_active_requests | |||
| @classmethod | |||
| def generate_single_iteration(cls, app_model: App, | |||
| user: Union[Account, EndUser], | |||
| node_id: str, | |||
| args: Any, | |||
| streaming: bool = True): | |||
| def generate_single_iteration( | |||
| cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True | |||
| ): | |||
| if app_model.mode == AppMode.ADVANCED_CHAT.value: | |||
| workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) | |||
| return AdvancedChatAppGenerator().single_iteration_generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| node_id=node_id, | |||
| user=user, | |||
| args=args, | |||
| stream=streaming | |||
| app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming | |||
| ) | |||
| elif app_model.mode == AppMode.WORKFLOW.value: | |||
| workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) | |||
| return WorkflowAppGenerator().single_iteration_generate( | |||
| app_model=app_model, | |||
| workflow=workflow, | |||
| node_id=node_id, | |||
| user=user, | |||
| args=args, | |||
| stream=streaming | |||
| app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming | |||
| ) | |||
| else: | |||
| raise ValueError(f'Invalid app mode {app_model.mode}') | |||
| raise ValueError(f"Invalid app mode {app_model.mode}") | |||
| @classmethod | |||
| def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], | |||
| message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ | |||
| -> Union[dict, Generator]: | |||
| def generate_more_like_this( | |||
| cls, | |||
| app_model: App, | |||
| user: Union[Account, EndUser], | |||
| message_id: str, | |||
| invoke_from: InvokeFrom, | |||
| streaming: bool = True, | |||
| ) -> Union[dict, Generator]: | |||
| """ | |||
| Generate more like this | |||
| :param app_model: app model | |||
| @@ -136,11 +133,7 @@ class AppGenerateService: | |||
| :return: | |||
| """ | |||
| return CompletionAppGenerator().generate_more_like_this( | |||
| app_model=app_model, | |||
| message_id=message_id, | |||
| user=user, | |||
| invoke_from=invoke_from, | |||
| stream=streaming | |||
| app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming | |||
| ) | |||
| @classmethod | |||
| @@ -157,12 +150,12 @@ class AppGenerateService: | |||
| workflow = workflow_service.get_draft_workflow(app_model=app_model) | |||
| if not workflow: | |||
| raise ValueError('Workflow not initialized') | |||
| raise ValueError("Workflow not initialized") | |||
| else: | |||
| # fetch published workflow by app_model | |||
| workflow = workflow_service.get_published_workflow(app_model=app_model) | |||
| if not workflow: | |||
| raise ValueError('Workflow not published') | |||
| raise ValueError("Workflow not published") | |||
| return workflow | |||
| @@ -5,7 +5,6 @@ from models.model import AppMode | |||
| class AppModelConfigService: | |||
| @classmethod | |||
| def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: | |||
| if app_mode == AppMode.CHAT: | |||
| @@ -33,27 +33,22 @@ class AppService: | |||
| :param args: request args | |||
| :return: | |||
| """ | |||
| filters = [ | |||
| App.tenant_id == tenant_id, | |||
| App.is_universal == False | |||
| ] | |||
| filters = [App.tenant_id == tenant_id, App.is_universal == False] | |||
| if args['mode'] == 'workflow': | |||
| if args["mode"] == "workflow": | |||
| filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) | |||
| elif args['mode'] == 'chat': | |||
| elif args["mode"] == "chat": | |||
| filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) | |||
| elif args['mode'] == 'agent-chat': | |||
| elif args["mode"] == "agent-chat": | |||
| filters.append(App.mode == AppMode.AGENT_CHAT.value) | |||
| elif args['mode'] == 'channel': | |||
| elif args["mode"] == "channel": | |||
| filters.append(App.mode == AppMode.CHANNEL.value) | |||
| if args.get('name'): | |||
| name = args['name'][:30] | |||
| filters.append(App.name.ilike(f'%{name}%')) | |||
| if args.get('tag_ids'): | |||
| target_ids = TagService.get_target_ids_by_tag_ids('app', | |||
| tenant_id, | |||
| args['tag_ids']) | |||
| if args.get("name"): | |||
| name = args["name"][:30] | |||
| filters.append(App.name.ilike(f"%{name}%")) | |||
| if args.get("tag_ids"): | |||
| target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) | |||
| if target_ids: | |||
| filters.append(App.id.in_(target_ids)) | |||
| else: | |||
| @@ -61,9 +56,9 @@ class AppService: | |||
| app_models = db.paginate( | |||
| db.select(App).where(*filters).order_by(App.created_at.desc()), | |||
| page=args['page'], | |||
| per_page=args['limit'], | |||
| error_out=False | |||
| page=args["page"], | |||
| per_page=args["limit"], | |||
| error_out=False, | |||
| ) | |||
| return app_models | |||
| @@ -75,21 +70,20 @@ class AppService: | |||
| :param args: request args | |||
| :param account: Account instance | |||
| """ | |||
| app_mode = AppMode.value_of(args['mode']) | |||
| app_mode = AppMode.value_of(args["mode"]) | |||
| app_template = default_app_templates[app_mode] | |||
| # get model config | |||
| default_model_config = app_template.get('model_config') | |||
| default_model_config = app_template.get("model_config") | |||
| default_model_config = default_model_config.copy() if default_model_config else None | |||
| if default_model_config and 'model' in default_model_config: | |||
| if default_model_config and "model" in default_model_config: | |||
| # get model provider | |||
| model_manager = ModelManager() | |||
| # get default model instance | |||
| try: | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=account.current_tenant_id, | |||
| model_type=ModelType.LLM | |||
| tenant_id=account.current_tenant_id, model_type=ModelType.LLM | |||
| ) | |||
| except (ProviderTokenNotInitError, LLMBadRequestError): | |||
| model_instance = None | |||
| @@ -98,39 +92,41 @@ class AppService: | |||
| model_instance = None | |||
| if model_instance: | |||
| if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: | |||
| default_model_dict = default_model_config['model'] | |||
| if ( | |||
| model_instance.model == default_model_config["model"]["name"] | |||
| and model_instance.provider == default_model_config["model"]["provider"] | |||
| ): | |||
| default_model_dict = default_model_config["model"] | |||
| else: | |||
| llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) | |||
| model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) | |||
| default_model_dict = { | |||
| 'provider': model_instance.provider, | |||
| 'name': model_instance.model, | |||
| 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), | |||
| 'completion_params': {} | |||
| "provider": model_instance.provider, | |||
| "name": model_instance.model, | |||
| "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), | |||
| "completion_params": {}, | |||
| } | |||
| else: | |||
| provider, model = model_manager.get_default_provider_model_name( | |||
| tenant_id=account.current_tenant_id, | |||
| model_type=ModelType.LLM | |||
| tenant_id=account.current_tenant_id, model_type=ModelType.LLM | |||
| ) | |||
| default_model_config['model']['provider'] = provider | |||
| default_model_config['model']['name'] = model | |||
| default_model_dict = default_model_config['model'] | |||
| default_model_config['model'] = json.dumps(default_model_dict) | |||
| app = App(**app_template['app']) | |||
| app.name = args['name'] | |||
| app.description = args.get('description', '') | |||
| app.mode = args['mode'] | |||
| app.icon_type = args.get('icon_type', 'emoji') | |||
| app.icon = args['icon'] | |||
| app.icon_background = args['icon_background'] | |||
| default_model_config["model"]["provider"] = provider | |||
| default_model_config["model"]["name"] = model | |||
| default_model_dict = default_model_config["model"] | |||
| default_model_config["model"] = json.dumps(default_model_dict) | |||
| app = App(**app_template["app"]) | |||
| app.name = args["name"] | |||
| app.description = args.get("description", "") | |||
| app.mode = args["mode"] | |||
| app.icon_type = args.get("icon_type", "emoji") | |||
| app.icon = args["icon"] | |||
| app.icon_background = args["icon_background"] | |||
| app.tenant_id = tenant_id | |||
| app.api_rph = args.get('api_rph', 0) | |||
| app.api_rpm = args.get('api_rpm', 0) | |||
| app.api_rph = args.get("api_rph", 0) | |||
| app.api_rpm = args.get("api_rpm", 0) | |||
| db.session.add(app) | |||
| db.session.flush() | |||
| @@ -158,7 +154,7 @@ class AppService: | |||
| model_config: AppModelConfig = app.app_model_config | |||
| agent_mode = model_config.agent_mode_dict | |||
| # decrypt agent tool parameters if it's secret-input | |||
| for tool in agent_mode.get('tools') or []: | |||
| for tool in agent_mode.get("tools") or []: | |||
| if not isinstance(tool, dict) or len(tool.keys()) <= 3: | |||
| continue | |||
| agent_tool_entity = AgentToolEntity(**tool) | |||
| @@ -174,7 +170,7 @@ class AppService: | |||
| tool_runtime=tool_runtime, | |||
| provider_name=agent_tool_entity.provider_id, | |||
| provider_type=agent_tool_entity.provider_type, | |||
| identity_id=f'AGENT.{app.id}' | |||
| identity_id=f"AGENT.{app.id}", | |||
| ) | |||
| # get decrypted parameters | |||
| @@ -185,7 +181,7 @@ class AppService: | |||
| masked_parameter = {} | |||
| # override tool parameters | |||
| tool['tool_parameters'] = masked_parameter | |||
| tool["tool_parameters"] = masked_parameter | |||
| except Exception as e: | |||
| pass | |||
| @@ -215,12 +211,12 @@ class AppService: | |||
| :param args: request args | |||
| :return: App instance | |||
| """ | |||
| app.name = args.get('name') | |||
| app.description = args.get('description', '') | |||
| app.max_active_requests = args.get('max_active_requests') | |||
| app.icon_type = args.get('icon_type', 'emoji') | |||
| app.icon = args.get('icon') | |||
| app.icon_background = args.get('icon_background') | |||
| app.name = args.get("name") | |||
| app.description = args.get("description", "") | |||
| app.max_active_requests = args.get("max_active_requests") | |||
| app.icon_type = args.get("icon_type", "emoji") | |||
| app.icon = args.get("icon") | |||
| app.icon_background = args.get("icon_background") | |||
| app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| db.session.commit() | |||
| @@ -298,10 +294,7 @@ class AppService: | |||
| db.session.commit() | |||
| # Trigger asynchronous deletion of app and related data | |||
| remove_app_and_related_data_task.delay( | |||
| tenant_id=app.tenant_id, | |||
| app_id=app.id | |||
| ) | |||
| remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) | |||
| def get_app_meta(self, app_model: App) -> dict: | |||
| """ | |||
| @@ -311,9 +304,7 @@ class AppService: | |||
| """ | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| meta = { | |||
| 'tool_icons': {} | |||
| } | |||
| meta = {"tool_icons": {}} | |||
| if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: | |||
| workflow = app_model.workflow | |||
| @@ -321,17 +312,19 @@ class AppService: | |||
| return meta | |||
| graph = workflow.graph_dict | |||
| nodes = graph.get('nodes', []) | |||
| nodes = graph.get("nodes", []) | |||
| tools = [] | |||
| for node in nodes: | |||
| if node.get('data', {}).get('type') == 'tool': | |||
| node_data = node.get('data', {}) | |||
| tools.append({ | |||
| 'provider_type': node_data.get('provider_type'), | |||
| 'provider_id': node_data.get('provider_id'), | |||
| 'tool_name': node_data.get('tool_name'), | |||
| 'tool_parameters': {} | |||
| }) | |||
| if node.get("data", {}).get("type") == "tool": | |||
| node_data = node.get("data", {}) | |||
| tools.append( | |||
| { | |||
| "provider_type": node_data.get("provider_type"), | |||
| "provider_id": node_data.get("provider_id"), | |||
| "tool_name": node_data.get("tool_name"), | |||
| "tool_parameters": {}, | |||
| } | |||
| ) | |||
| else: | |||
| app_model_config: AppModelConfig = app_model.app_model_config | |||
| @@ -341,30 +334,26 @@ class AppService: | |||
| agent_config = app_model_config.agent_mode_dict or {} | |||
| # get all tools | |||
| tools = agent_config.get('tools', []) | |||
| tools = agent_config.get("tools", []) | |||
| url_prefix = (dify_config.CONSOLE_API_URL | |||
| + "/console/api/workspaces/current/tool-provider/builtin/") | |||
| url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" | |||
| for tool in tools: | |||
| keys = list(tool.keys()) | |||
| if len(keys) >= 4: | |||
| # current tool standard | |||
| provider_type = tool.get('provider_type') | |||
| provider_id = tool.get('provider_id') | |||
| tool_name = tool.get('tool_name') | |||
| if provider_type == 'builtin': | |||
| meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' | |||
| elif provider_type == 'api': | |||
| provider_type = tool.get("provider_type") | |||
| provider_id = tool.get("provider_id") | |||
| tool_name = tool.get("tool_name") | |||
| if provider_type == "builtin": | |||
| meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" | |||
| elif provider_type == "api": | |||
| try: | |||
| provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.id == provider_id | |||
| ).first() | |||
| meta['tool_icons'][tool_name] = json.loads(provider.icon) | |||
| provider: ApiToolProvider = ( | |||
| db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() | |||
| ) | |||
| meta["tool_icons"][tool_name] = json.loads(provider.icon) | |||
| except: | |||
| meta['tool_icons'][tool_name] = { | |||
| "background": "#252525", | |||
| "content": "\ud83d\ude01" | |||
| } | |||
| meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} | |||
| return meta | |||
| @@ -17,7 +17,7 @@ from services.errors.audio import ( | |||
| FILE_SIZE = 30 | |||
| FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 | |||
| ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] | |||
| ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] | |||
| logger = logging.getLogger(__name__) | |||
| @@ -31,19 +31,19 @@ class AudioService: | |||
| raise ValueError("Speech to text is not enabled") | |||
| features_dict = workflow.features_dict | |||
| if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): | |||
| if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): | |||
| raise ValueError("Speech to text is not enabled") | |||
| else: | |||
| app_model_config: AppModelConfig = app_model.app_model_config | |||
| if not app_model_config.speech_to_text_dict['enabled']: | |||
| if not app_model_config.speech_to_text_dict["enabled"]: | |||
| raise ValueError("Speech to text is not enabled") | |||
| if file is None: | |||
| raise NoAudioUploadedServiceError() | |||
| extension = file.mimetype | |||
| if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]: | |||
| if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: | |||
| raise UnsupportedAudioTypeServiceError() | |||
| file_content = file.read() | |||
| @@ -55,20 +55,25 @@ class AudioService: | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=app_model.tenant_id, | |||
| model_type=ModelType.SPEECH2TEXT | |||
| tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT | |||
| ) | |||
| if model_instance is None: | |||
| raise ProviderNotSupportSpeechToTextServiceError() | |||
| buffer = io.BytesIO(file_content) | |||
| buffer.name = 'temp.mp3' | |||
| buffer.name = "temp.mp3" | |||
| return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} | |||
| @classmethod | |||
| def transcript_tts(cls, app_model: App, text: Optional[str] = None, | |||
| voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): | |||
| def transcript_tts( | |||
| cls, | |||
| app_model: App, | |||
| text: Optional[str] = None, | |||
| voice: Optional[str] = None, | |||
| end_user: Optional[str] = None, | |||
| message_id: Optional[str] = None, | |||
| ): | |||
| from collections.abc import Generator | |||
| from flask import Response, stream_with_context | |||
| @@ -84,65 +89,56 @@ class AudioService: | |||
| raise ValueError("TTS is not enabled") | |||
| features_dict = workflow.features_dict | |||
| if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): | |||
| if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"): | |||
| raise ValueError("TTS is not enabled") | |||
| voice = features_dict['text_to_speech'].get('voice') if voice is None else voice | |||
| voice = features_dict["text_to_speech"].get("voice") if voice is None else voice | |||
| else: | |||
| text_to_speech_dict = app_model.app_model_config.text_to_speech_dict | |||
| if not text_to_speech_dict.get('enabled'): | |||
| if not text_to_speech_dict.get("enabled"): | |||
| raise ValueError("TTS is not enabled") | |||
| voice = text_to_speech_dict.get('voice') if voice is None else voice | |||
| voice = text_to_speech_dict.get("voice") if voice is None else voice | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=app_model.tenant_id, | |||
| model_type=ModelType.TTS | |||
| tenant_id=app_model.tenant_id, model_type=ModelType.TTS | |||
| ) | |||
| try: | |||
| if not voice: | |||
| voices = model_instance.get_tts_voices() | |||
| if voices: | |||
| voice = voices[0].get('value') | |||
| voice = voices[0].get("value") | |||
| else: | |||
| raise ValueError("Sorry, no voice available.") | |||
| return model_instance.invoke_tts( | |||
| content_text=text_content.strip(), | |||
| user=end_user, | |||
| tenant_id=app_model.tenant_id, | |||
| voice=voice | |||
| content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice | |||
| ) | |||
| except Exception as e: | |||
| raise e | |||
| if message_id: | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id | |||
| ).first() | |||
| if message.answer == '' and message.status == 'normal': | |||
| message = db.session.query(Message).filter(Message.id == message_id).first() | |||
| if message.answer == "" and message.status == "normal": | |||
| return None | |||
| else: | |||
| response = invoke_tts(message.answer, app_model=app_model, voice=voice) | |||
| if isinstance(response, Generator): | |||
| return Response(stream_with_context(response), content_type='audio/mpeg') | |||
| return Response(stream_with_context(response), content_type="audio/mpeg") | |||
| return response | |||
| else: | |||
| response = invoke_tts(text, app_model, voice) | |||
| if isinstance(response, Generator): | |||
| return Response(stream_with_context(response), content_type='audio/mpeg') | |||
| return Response(stream_with_context(response), content_type="audio/mpeg") | |||
| return response | |||
| @classmethod | |||
| def transcript_tts_voices(cls, tenant_id: str, language: str): | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TTS | |||
| ) | |||
| model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) | |||
| if model_instance is None: | |||
| raise ProviderNotSupportTextToSpeechServiceError() | |||
| @@ -1,14 +1,12 @@ | |||
| from services.auth.firecrawl import FirecrawlAuth | |||
| class ApiKeyAuthFactory: | |||
| def __init__(self, provider: str, credentials: dict): | |||
| if provider == 'firecrawl': | |||
| if provider == "firecrawl": | |||
| self.auth = FirecrawlAuth(credentials) | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| raise ValueError("Invalid provider") | |||
| def validate_credentials(self): | |||
| return self.auth.validate_credentials() | |||
| @@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory | |||
| class ApiKeyAuthService: | |||
| @staticmethod | |||
| def get_provider_auth_list(tenant_id: str) -> list: | |||
| data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.disabled.is_(False) | |||
| ).all() | |||
| data_source_api_key_bindings = ( | |||
| db.session.query(DataSourceApiKeyAuthBinding) | |||
| .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) | |||
| .all() | |||
| ) | |||
| return data_source_api_key_bindings | |||
| @staticmethod | |||
| def create_provider_auth(tenant_id: str, args: dict): | |||
| auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() | |||
| auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() | |||
| if auth_result: | |||
| # Encrypt the api key | |||
| api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) | |||
| args['credentials']['config']['api_key'] = api_key | |||
| api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) | |||
| args["credentials"]["config"]["api_key"] = api_key | |||
| data_source_api_key_binding = DataSourceApiKeyAuthBinding() | |||
| data_source_api_key_binding.tenant_id = tenant_id | |||
| data_source_api_key_binding.category = args['category'] | |||
| data_source_api_key_binding.provider = args['provider'] | |||
| data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) | |||
| data_source_api_key_binding.category = args["category"] | |||
| data_source_api_key_binding.provider = args["provider"] | |||
| data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) | |||
| db.session.add(data_source_api_key_binding) | |||
| db.session.commit() | |||
| @staticmethod | |||
| def get_auth_credentials(tenant_id: str, category: str, provider: str): | |||
| data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.category == category, | |||
| DataSourceApiKeyAuthBinding.provider == provider, | |||
| DataSourceApiKeyAuthBinding.disabled.is_(False) | |||
| ).first() | |||
| data_source_api_key_bindings = ( | |||
| db.session.query(DataSourceApiKeyAuthBinding) | |||
| .filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.category == category, | |||
| DataSourceApiKeyAuthBinding.provider == provider, | |||
| DataSourceApiKeyAuthBinding.disabled.is_(False), | |||
| ) | |||
| .first() | |||
| ) | |||
| if not data_source_api_key_bindings: | |||
| return None | |||
| credentials = json.loads(data_source_api_key_bindings.credentials) | |||
| @@ -47,24 +51,24 @@ class ApiKeyAuthService: | |||
| @staticmethod | |||
| def delete_provider_auth(tenant_id: str, binding_id: str): | |||
| data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.id == binding_id | |||
| ).first() | |||
| data_source_api_key_binding = ( | |||
| db.session.query(DataSourceApiKeyAuthBinding) | |||
| .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) | |||
| .first() | |||
| ) | |||
| if data_source_api_key_binding: | |||
| db.session.delete(data_source_api_key_binding) | |||
| db.session.commit() | |||
| @classmethod | |||
| def validate_api_key_auth_args(cls, args): | |||
| if 'category' not in args or not args['category']: | |||
| raise ValueError('category is required') | |||
| if 'provider' not in args or not args['provider']: | |||
| raise ValueError('provider is required') | |||
| if 'credentials' not in args or not args['credentials']: | |||
| raise ValueError('credentials is required') | |||
| if not isinstance(args['credentials'], dict): | |||
| raise ValueError('credentials must be a dictionary') | |||
| if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: | |||
| raise ValueError('auth_type is required') | |||
| if "category" not in args or not args["category"]: | |||
| raise ValueError("category is required") | |||
| if "provider" not in args or not args["provider"]: | |||
| raise ValueError("provider is required") | |||
| if "credentials" not in args or not args["credentials"]: | |||
| raise ValueError("credentials is required") | |||
| if not isinstance(args["credentials"], dict): | |||
| raise ValueError("credentials must be a dictionary") | |||
| if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: | |||
| raise ValueError("auth_type is required") | |||
| @@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase | |||
| class FirecrawlAuth(ApiKeyAuthBase): | |||
| def __init__(self, credentials: dict): | |||
| super().__init__(credentials) | |||
| auth_type = credentials.get('auth_type') | |||
| if auth_type != 'bearer': | |||
| raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') | |||
| self.api_key = credentials.get('config').get('api_key', None) | |||
| self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') | |||
| auth_type = credentials.get("auth_type") | |||
| if auth_type != "bearer": | |||
| raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") | |||
| self.api_key = credentials.get("config").get("api_key", None) | |||
| self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") | |||
| if not self.api_key: | |||
| raise ValueError('No API key provided') | |||
| raise ValueError("No API key provided") | |||
| def validate_credentials(self): | |||
| headers = self._prepare_headers() | |||
| options = { | |||
| 'url': 'https://example.com', | |||
| 'crawlerOptions': { | |||
| 'excludes': [], | |||
| 'includes': [], | |||
| 'limit': 1 | |||
| }, | |||
| 'pageOptions': { | |||
| 'onlyMainContent': True | |||
| } | |||
| "url": "https://example.com", | |||
| "crawlerOptions": {"excludes": [], "includes": [], "limit": 1}, | |||
| "pageOptions": {"onlyMainContent": True}, | |||
| } | |||
| response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) | |||
| response = self._post_request(f"{self.base_url}/v0/crawl", options, headers) | |||
| if response.status_code == 200: | |||
| return True | |||
| else: | |||
| self._handle_error(response) | |||
| def _prepare_headers(self): | |||
| return { | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f'Bearer {self.api_key}' | |||
| } | |||
| return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} | |||
| def _post_request(self, url, data, headers): | |||
| return requests.post(url, headers=headers, json=data) | |||
| def _handle_error(self, response): | |||
| if response.status_code in [402, 409, 500]: | |||
| error_message = response.json().get('error', 'Unknown error occurred') | |||
| raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') | |||
| error_message = response.json().get("error", "Unknown error occurred") | |||
| raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") | |||
| else: | |||
| if response.text: | |||
| error_message = json.loads(response.text).get('error', 'Unknown error occurred') | |||
| raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') | |||
| raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') | |||
| error_message = json.loads(response.text).get("error", "Unknown error occurred") | |||
| raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") | |||
| raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") | |||
| @@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole | |||
| class BillingService: | |||
| base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') | |||
| secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') | |||
| base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") | |||
| secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") | |||
| @classmethod | |||
| def get_info(cls, tenant_id: str): | |||
| params = {'tenant_id': tenant_id} | |||
| params = {"tenant_id": tenant_id} | |||
| billing_info = cls._send_request('GET', '/subscription/info', params=params) | |||
| billing_info = cls._send_request("GET", "/subscription/info", params=params) | |||
| return billing_info | |||
| @classmethod | |||
| def get_subscription(cls, plan: str, | |||
| interval: str, | |||
| prefilled_email: str = '', | |||
| tenant_id: str = ''): | |||
| params = { | |||
| 'plan': plan, | |||
| 'interval': interval, | |||
| 'prefilled_email': prefilled_email, | |||
| 'tenant_id': tenant_id | |||
| } | |||
| return cls._send_request('GET', '/subscription/payment-link', params=params) | |||
| def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): | |||
| params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} | |||
| return cls._send_request("GET", "/subscription/payment-link", params=params) | |||
| @classmethod | |||
| def get_model_provider_payment_link(cls, | |||
| provider_name: str, | |||
| tenant_id: str, | |||
| account_id: str, | |||
| prefilled_email: str): | |||
| def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): | |||
| params = { | |||
| 'provider_name': provider_name, | |||
| 'tenant_id': tenant_id, | |||
| 'account_id': account_id, | |||
| 'prefilled_email': prefilled_email | |||
| "provider_name": provider_name, | |||
| "tenant_id": tenant_id, | |||
| "account_id": account_id, | |||
| "prefilled_email": prefilled_email, | |||
| } | |||
| return cls._send_request('GET', '/model-provider/payment-link', params=params) | |||
| return cls._send_request("GET", "/model-provider/payment-link", params=params) | |||
| @classmethod | |||
| def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''): | |||
| params = { | |||
| 'prefilled_email': prefilled_email, | |||
| 'tenant_id': tenant_id | |||
| } | |||
| return cls._send_request('GET', '/invoices', params=params) | |||
| def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): | |||
| params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} | |||
| return cls._send_request("GET", "/invoices", params=params) | |||
| @classmethod | |||
| def _send_request(cls, method, endpoint, json=None, params=None): | |||
| headers = { | |||
| "Content-Type": "application/json", | |||
| "Billing-Api-Secret-Key": cls.secret_key | |||
| } | |||
| headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||
| @@ -69,10 +51,11 @@ class BillingService: | |||
| def is_tenant_owner_or_admin(current_user): | |||
| tenant_id = current_user.current_tenant_id | |||
| join = db.session.query(TenantAccountJoin).filter( | |||
| TenantAccountJoin.tenant_id == tenant_id, | |||
| TenantAccountJoin.account_id == current_user.id | |||
| ).first() | |||
| join = ( | |||
| db.session.query(TenantAccountJoin) | |||
| .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) | |||
| .first() | |||
| ) | |||
| if not TenantAccountRole.is_privileged_role(join.role): | |||
| raise ValueError('Only team owner or team admin can perform this action') | |||
| raise ValueError("Only team owner or team admin can perform this action") | |||
| @@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension | |||
| class CodeBasedExtensionService: | |||
| @staticmethod | |||
| def get_code_based_extension(module: str) -> list[dict]: | |||
| module_extensions = code_based_extension.module_extensions(module) | |||
| return [{ | |||
| 'name': module_extension.name, | |||
| 'label': module_extension.label, | |||
| 'form_schema': module_extension.form_schema | |||
| } for module_extension in module_extensions if not module_extension.builtin] | |||
| return [ | |||
| { | |||
| "name": module_extension.name, | |||
| "label": module_extension.label, | |||
| "form_schema": module_extension.form_schema, | |||
| } | |||
| for module_extension in module_extensions | |||
| if not module_extension.builtin | |||
| ] | |||
| @@ -15,22 +15,27 @@ from services.errors.message import MessageNotExistsError | |||
| class ConversationService: | |||
| @classmethod | |||
| def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], limit: int, | |||
| invoke_from: InvokeFrom, | |||
| include_ids: Optional[list] = None, | |||
| exclude_ids: Optional[list] = None, | |||
| sort_by: str = '-updated_at') -> InfiniteScrollPagination: | |||
| def pagination_by_last_id( | |||
| cls, | |||
| app_model: App, | |||
| user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], | |||
| limit: int, | |||
| invoke_from: InvokeFrom, | |||
| include_ids: Optional[list] = None, | |||
| exclude_ids: Optional[list] = None, | |||
| sort_by: str = "-updated_at", | |||
| ) -> InfiniteScrollPagination: | |||
| if not user: | |||
| return InfiniteScrollPagination(data=[], limit=limit, has_more=False) | |||
| base_query = db.session.query(Conversation).filter( | |||
| Conversation.is_deleted == False, | |||
| Conversation.app_id == app_model.id, | |||
| Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), | |||
| Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Conversation.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) | |||
| or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), | |||
| ) | |||
| if include_ids is not None: | |||
| @@ -58,28 +63,26 @@ class ConversationService: | |||
| has_more = False | |||
| if len(conversations) == limit: | |||
| current_page_last_conversation = conversations[-1] | |||
| rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction, | |||
| current_page_last_conversation, is_next_page=True) | |||
| rest_filter_condition = cls._build_filter_condition( | |||
| sort_field, sort_direction, current_page_last_conversation, is_next_page=True | |||
| ) | |||
| rest_count = base_query.filter(rest_filter_condition).count() | |||
| if rest_count > 0: | |||
| has_more = True | |||
| return InfiniteScrollPagination( | |||
| data=conversations, | |||
| limit=limit, | |||
| has_more=has_more | |||
| ) | |||
| return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) | |||
| @classmethod | |||
| def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: | |||
| if sort_by.startswith('-'): | |||
| if sort_by.startswith("-"): | |||
| return sort_by[1:], desc | |||
| return sort_by, asc | |||
| @classmethod | |||
| def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, | |||
| is_next_page: bool = False): | |||
| def _build_filter_condition( | |||
| cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False | |||
| ): | |||
| field_value = getattr(reference_conversation, sort_field) | |||
| if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): | |||
| return getattr(Conversation, sort_field) < field_value | |||
| @@ -87,8 +90,14 @@ class ConversationService: | |||
| return getattr(Conversation, sort_field) > field_value | |||
| @classmethod | |||
| def rename(cls, app_model: App, conversation_id: str, | |||
| user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): | |||
| def rename( | |||
| cls, | |||
| app_model: App, | |||
| conversation_id: str, | |||
| user: Optional[Union[Account, EndUser]], | |||
| name: str, | |||
| auto_generate: bool, | |||
| ): | |||
| conversation = cls.get_conversation(app_model, conversation_id, user) | |||
| if auto_generate: | |||
| @@ -103,11 +112,12 @@ class ConversationService: | |||
| @classmethod | |||
| def auto_generate_name(cls, app_model: App, conversation: Conversation): | |||
| # get conversation first message | |||
| message = db.session.query(Message) \ | |||
| .filter( | |||
| Message.app_id == app_model.id, | |||
| Message.conversation_id == conversation.id | |||
| ).order_by(Message.created_at.asc()).first() | |||
| message = ( | |||
| db.session.query(Message) | |||
| .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) | |||
| .order_by(Message.created_at.asc()) | |||
| .first() | |||
| ) | |||
| if not message: | |||
| raise MessageNotExistsError() | |||
| @@ -127,15 +137,18 @@ class ConversationService: | |||
| @classmethod | |||
| def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): | |||
| conversation = db.session.query(Conversation) \ | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter( | |||
| Conversation.id == conversation_id, | |||
| Conversation.app_id == app_model.id, | |||
| Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), | |||
| Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Conversation.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| Conversation.is_deleted == False | |||
| ).first() | |||
| Conversation.id == conversation_id, | |||
| Conversation.app_id == app_model.id, | |||
| Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Conversation.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| Conversation.is_deleted == False, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not conversation: | |||
| raise ConversationNotExistsError() | |||
| @@ -4,15 +4,12 @@ import requests | |||
| class EnterpriseRequest: | |||
| base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') | |||
| secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') | |||
| base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") | |||
| secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") | |||
| @classmethod | |||
| def send_request(cls, method, endpoint, json=None, params=None): | |||
| headers = { | |||
| "Content-Type": "application/json", | |||
| "Enterprise-Api-Secret-Key": cls.secret_key | |||
| } | |||
| headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||
| @@ -2,11 +2,10 @@ from services.enterprise.base import EnterpriseRequest | |||
| class EnterpriseService: | |||
| @classmethod | |||
| def get_info(cls): | |||
| return EnterpriseRequest.send_request('GET', '/info') | |||
| return EnterpriseRequest.send_request("GET", "/info") | |||
| @classmethod | |||
| def get_app_web_sso_enabled(cls, app_code): | |||
| return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}') | |||
| return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") | |||
| @@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum): | |||
| """ | |||
| Enum class for custom configuration status. | |||
| """ | |||
| ACTIVE = 'active' | |||
| NO_CONFIGURE = 'no-configure' | |||
| ACTIVE = "active" | |||
| NO_CONFIGURE = "no-configure" | |||
| class CustomConfigurationResponse(BaseModel): | |||
| """ | |||
| Model class for provider custom configuration response. | |||
| """ | |||
| status: CustomConfigurationStatus | |||
| @@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel): | |||
| """ | |||
| Model class for provider system configuration response. | |||
| """ | |||
| enabled: bool | |||
| current_quota_type: Optional[ProviderQuotaType] = None | |||
| quota_configurations: list[QuotaConfiguration] = [] | |||
| @@ -46,6 +49,7 @@ class ProviderResponse(BaseModel): | |||
| """ | |||
| Model class for provider response. | |||
| """ | |||
| provider: str | |||
| label: I18nObject | |||
| description: Optional[I18nObject] = None | |||
| @@ -67,18 +71,15 @@ class ProviderResponse(BaseModel): | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
| url_prefix = (dify_config.CONSOLE_API_URL | |||
| + f"/console/api/workspaces/current/model-providers/{self.provider}") | |||
| url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" | |||
| if self.icon_small is not None: | |||
| self.icon_small = I18nObject( | |||
| en_US=f"{url_prefix}/icon_small/en_US", | |||
| zh_Hans=f"{url_prefix}/icon_small/zh_Hans" | |||
| en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" | |||
| ) | |||
| if self.icon_large is not None: | |||
| self.icon_large = I18nObject( | |||
| en_US=f"{url_prefix}/icon_large/en_US", | |||
| zh_Hans=f"{url_prefix}/icon_large/zh_Hans" | |||
| en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" | |||
| ) | |||
| @@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel): | |||
| """ | |||
| Model class for provider with models response. | |||
| """ | |||
| provider: str | |||
| label: I18nObject | |||
| icon_small: Optional[I18nObject] = None | |||
| @@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel): | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
| url_prefix = (dify_config.CONSOLE_API_URL | |||
| + f"/console/api/workspaces/current/model-providers/{self.provider}") | |||
| url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" | |||
| if self.icon_small is not None: | |||
| self.icon_small = I18nObject( | |||
| en_US=f"{url_prefix}/icon_small/en_US", | |||
| zh_Hans=f"{url_prefix}/icon_small/zh_Hans" | |||
| en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" | |||
| ) | |||
| if self.icon_large is not None: | |||
| self.icon_large = I18nObject( | |||
| en_US=f"{url_prefix}/icon_large/en_US", | |||
| zh_Hans=f"{url_prefix}/icon_large/zh_Hans" | |||
| en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" | |||
| ) | |||
| @@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
| url_prefix = (dify_config.CONSOLE_API_URL | |||
| + f"/console/api/workspaces/current/model-providers/{self.provider}") | |||
| url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" | |||
| if self.icon_small is not None: | |||
| self.icon_small = I18nObject( | |||
| en_US=f"{url_prefix}/icon_small/en_US", | |||
| zh_Hans=f"{url_prefix}/icon_small/zh_Hans" | |||
| en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" | |||
| ) | |||
| if self.icon_large is not None: | |||
| self.icon_large = I18nObject( | |||
| en_US=f"{url_prefix}/icon_large/en_US", | |||
| zh_Hans=f"{url_prefix}/icon_large/zh_Hans" | |||
| en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" | |||
| ) | |||
| @@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel): | |||
| """ | |||
| Default model entity. | |||
| """ | |||
| model: str | |||
| model_type: ModelType | |||
| provider: SimpleProviderEntityResponse | |||
| @@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): | |||
| """ | |||
| Model with provider entity. | |||
| """ | |||
| provider: SimpleProviderEntityResponse | |||
| def __init__(self, model: ModelWithProviderEntity) -> None: | |||
| @@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError): | |||
| class RateLimitExceededError(BaseServiceError): | |||
| pass | |||
| @@ -1,3 +1,3 @@ | |||
| class BaseServiceError(Exception): | |||
| def __init__(self, description: str = None): | |||
| self.description = description | |||
| self.description = description | |||
| @@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService | |||
| class SubscriptionModel(BaseModel): | |||
| plan: str = 'sandbox' | |||
| interval: str = '' | |||
| plan: str = "sandbox" | |||
| interval: str = "" | |||
| class BillingModel(BaseModel): | |||
| @@ -27,7 +27,7 @@ class FeatureModel(BaseModel): | |||
| vector_space: LimitationModel = LimitationModel(size=0, limit=5) | |||
| annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) | |||
| documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) | |||
| docs_processing: str = 'standard' | |||
| docs_processing: str = "standard" | |||
| can_replace_logo: bool = False | |||
| model_load_balancing_enabled: bool = False | |||
| dataset_operator_enabled: bool = False | |||
| @@ -38,13 +38,13 @@ class FeatureModel(BaseModel): | |||
| class SystemFeatureModel(BaseModel): | |||
| sso_enforced_for_signin: bool = False | |||
| sso_enforced_for_signin_protocol: str = '' | |||
| sso_enforced_for_signin_protocol: str = "" | |||
| sso_enforced_for_web: bool = False | |||
| sso_enforced_for_web_protocol: str = '' | |||
| sso_enforced_for_web_protocol: str = "" | |||
| enable_web_sso_switch_component: bool = False | |||
| class FeatureService: | |||
| class FeatureService: | |||
| @classmethod | |||
| def get_features(cls, tenant_id: str) -> FeatureModel: | |||
| features = FeatureModel() | |||
| @@ -76,44 +76,44 @@ class FeatureService: | |||
| def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): | |||
| billing_info = BillingService.get_info(tenant_id) | |||
| features.billing.enabled = billing_info['enabled'] | |||
| features.billing.subscription.plan = billing_info['subscription']['plan'] | |||
| features.billing.subscription.interval = billing_info['subscription']['interval'] | |||
| features.billing.enabled = billing_info["enabled"] | |||
| features.billing.subscription.plan = billing_info["subscription"]["plan"] | |||
| features.billing.subscription.interval = billing_info["subscription"]["interval"] | |||
| if 'members' in billing_info: | |||
| features.members.size = billing_info['members']['size'] | |||
| features.members.limit = billing_info['members']['limit'] | |||
| if "members" in billing_info: | |||
| features.members.size = billing_info["members"]["size"] | |||
| features.members.limit = billing_info["members"]["limit"] | |||
| if 'apps' in billing_info: | |||
| features.apps.size = billing_info['apps']['size'] | |||
| features.apps.limit = billing_info['apps']['limit'] | |||
| if "apps" in billing_info: | |||
| features.apps.size = billing_info["apps"]["size"] | |||
| features.apps.limit = billing_info["apps"]["limit"] | |||
| if 'vector_space' in billing_info: | |||
| features.vector_space.size = billing_info['vector_space']['size'] | |||
| features.vector_space.limit = billing_info['vector_space']['limit'] | |||
| if "vector_space" in billing_info: | |||
| features.vector_space.size = billing_info["vector_space"]["size"] | |||
| features.vector_space.limit = billing_info["vector_space"]["limit"] | |||
| if 'documents_upload_quota' in billing_info: | |||
| features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] | |||
| features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] | |||
| if "documents_upload_quota" in billing_info: | |||
| features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] | |||
| features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] | |||
| if 'annotation_quota_limit' in billing_info: | |||
| features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] | |||
| features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] | |||
| if "annotation_quota_limit" in billing_info: | |||
| features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] | |||
| features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] | |||
| if 'docs_processing' in billing_info: | |||
| features.docs_processing = billing_info['docs_processing'] | |||
| if "docs_processing" in billing_info: | |||
| features.docs_processing = billing_info["docs_processing"] | |||
| if 'can_replace_logo' in billing_info: | |||
| features.can_replace_logo = billing_info['can_replace_logo'] | |||
| if "can_replace_logo" in billing_info: | |||
| features.can_replace_logo = billing_info["can_replace_logo"] | |||
| if 'model_load_balancing_enabled' in billing_info: | |||
| features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] | |||
| if "model_load_balancing_enabled" in billing_info: | |||
| features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] | |||
| @classmethod | |||
| def _fulfill_params_from_enterprise(cls, features): | |||
| enterprise_info = EnterpriseService.get_info() | |||
| features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] | |||
| features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] | |||
| features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] | |||
| features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] | |||
| features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] | |||
| features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] | |||
| features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] | |||
| features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] | |||
| @@ -17,27 +17,45 @@ from models.account import Account | |||
| from models.model import EndUser, UploadFile | |||
| from services.errors.file import FileTooLargeError, UnsupportedFileTypeError | |||
| IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] | |||
| IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] | |||
| IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) | |||
| ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] | |||
| UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', | |||
| 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] | |||
| ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] | |||
| UNSTRUCTURED_ALLOWED_EXTENSIONS = [ | |||
| "txt", | |||
| "markdown", | |||
| "md", | |||
| "pdf", | |||
| "html", | |||
| "htm", | |||
| "xlsx", | |||
| "xls", | |||
| "docx", | |||
| "csv", | |||
| "eml", | |||
| "msg", | |||
| "pptx", | |||
| "ppt", | |||
| "xml", | |||
| "epub", | |||
| ] | |||
| PREVIEW_WORDS_LIMIT = 3000 | |||
| class FileService: | |||
| @staticmethod | |||
| def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: | |||
| filename = file.filename | |||
| extension = file.filename.split('.')[-1] | |||
| extension = file.filename.split(".")[-1] | |||
| if len(filename) > 200: | |||
| filename = filename.split('.')[0][:200] + '.' + extension | |||
| filename = filename.split(".")[0][:200] + "." + extension | |||
| etl_type = dify_config.ETL_TYPE | |||
| allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ | |||
| allowed_extensions = ( | |||
| UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS | |||
| if etl_type == "Unstructured" | |||
| else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS | |||
| ) | |||
| if extension.lower() not in allowed_extensions: | |||
| raise UnsupportedFileTypeError() | |||
| elif only_image and extension.lower() not in IMAGE_EXTENSIONS: | |||
| @@ -55,7 +73,7 @@ class FileService: | |||
| file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 | |||
| if file_size > file_size_limit: | |||
| message = f'File size exceeded. {file_size} > {file_size_limit}' | |||
| message = f"File size exceeded. {file_size} > {file_size_limit}" | |||
| raise FileTooLargeError(message) | |||
| # user uuid as file name | |||
| @@ -67,7 +85,7 @@ class FileService: | |||
| # end_user | |||
| current_tenant_id = user.tenant_id | |||
| file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension | |||
| file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension | |||
| # save file to storage | |||
| storage.save(file_key, file_content) | |||
| @@ -81,11 +99,11 @@ class FileService: | |||
| size=file_size, | |||
| extension=extension, | |||
| mime_type=file.mimetype, | |||
| created_by_role=('account' if isinstance(user, Account) else 'end_user'), | |||
| created_by_role=("account" if isinstance(user, Account) else "end_user"), | |||
| created_by=user.id, | |||
| created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), | |||
| used=False, | |||
| hash=hashlib.sha3_256(file_content).hexdigest() | |||
| hash=hashlib.sha3_256(file_content).hexdigest(), | |||
| ) | |||
| db.session.add(upload_file) | |||
| @@ -99,10 +117,10 @@ class FileService: | |||
| text_name = text_name[:200] | |||
| # user uuid as file name | |||
| file_uuid = str(uuid.uuid4()) | |||
| file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' | |||
| file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" | |||
| # save file to storage | |||
| storage.save(file_key, text.encode('utf-8')) | |||
| storage.save(file_key, text.encode("utf-8")) | |||
| # save file to db | |||
| upload_file = UploadFile( | |||
| @@ -111,13 +129,13 @@ class FileService: | |||
| key=file_key, | |||
| name=text_name, | |||
| size=len(text), | |||
| extension='txt', | |||
| mime_type='text/plain', | |||
| extension="txt", | |||
| mime_type="text/plain", | |||
| created_by=current_user.id, | |||
| created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), | |||
| used=True, | |||
| used_by=current_user.id, | |||
| used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| db.session.add(upload_file) | |||
| @@ -127,9 +145,7 @@ class FileService: | |||
| @staticmethod | |||
| def get_file_preview(file_id: str) -> str: | |||
| upload_file = db.session.query(UploadFile) \ | |||
| .filter(UploadFile.id == file_id) \ | |||
| .first() | |||
| upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() | |||
| if not upload_file: | |||
| raise NotFound("File not found") | |||
| @@ -137,12 +153,12 @@ class FileService: | |||
| # extract text from file | |||
| extension = upload_file.extension | |||
| etl_type = dify_config.ETL_TYPE | |||
| allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS | |||
| allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS | |||
| if extension.lower() not in allowed_extensions: | |||
| raise UnsupportedFileTypeError() | |||
| text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) | |||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else '' | |||
| text = text[0:PREVIEW_WORDS_LIMIT] if text else "" | |||
| return text | |||
| @@ -152,9 +168,7 @@ class FileService: | |||
| if not result: | |||
| raise NotFound("File not found or signature is invalid") | |||
| upload_file = db.session.query(UploadFile) \ | |||
| .filter(UploadFile.id == file_id) \ | |||
| .first() | |||
| upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() | |||
| if not upload_file: | |||
| raise NotFound("File not found or signature is invalid") | |||
| @@ -170,9 +184,7 @@ class FileService: | |||
| @staticmethod | |||
| def get_public_image_preview(file_id: str) -> tuple[Generator, str]: | |||
| upload_file = db.session.query(UploadFile) \ | |||
| .filter(UploadFile.id == file_id) \ | |||
| .first() | |||
| upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() | |||
| if not upload_file: | |||
| raise NotFound("File not found or signature is invalid") | |||
| @@ -9,14 +9,11 @@ from models.account import Account | |||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||
| default_retrieval_model = { | |||
| 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| "reranking_enable": False, | |||
| "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| @@ -27,9 +24,9 @@ class HitTestingService: | |||
| return { | |||
| "query": { | |||
| "content": query, | |||
| "tsne_position": {'x': 0, 'y': 0}, | |||
| "tsne_position": {"x": 0, "y": 0}, | |||
| }, | |||
| "records": [] | |||
| "records": [], | |||
| } | |||
| start = time.perf_counter() | |||
| @@ -38,28 +35,28 @@ class HitTestingService: | |||
| if not retrieval_model: | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), | |||
| dataset_id=dataset.id, | |||
| query=cls.escape_query_for_search(query), | |||
| top_k=retrieval_model.get('top_k', 2), | |||
| score_threshold=retrieval_model.get('score_threshold', .0) | |||
| if retrieval_model['score_threshold_enabled'] else None, | |||
| reranking_model=retrieval_model.get('reranking_model', None) | |||
| if retrieval_model['reranking_enable'] else None, | |||
| reranking_mode=retrieval_model.get('reranking_mode') | |||
| if retrieval_model.get('reranking_mode') else 'reranking_model', | |||
| weights=retrieval_model.get('weights', None), | |||
| ) | |||
| all_documents = RetrievalService.retrieve( | |||
| retrival_method=retrieval_model.get("search_method", "semantic_search"), | |||
| dataset_id=dataset.id, | |||
| query=cls.escape_query_for_search(query), | |||
| top_k=retrieval_model.get("top_k", 2), | |||
| score_threshold=retrieval_model.get("score_threshold", 0.0) | |||
| if retrieval_model["score_threshold_enabled"] | |||
| else None, | |||
| reranking_model=retrieval_model.get("reranking_model", None) | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") | |||
| if retrieval_model.get("reranking_mode") | |||
| else "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| ) | |||
| end = time.perf_counter() | |||
| logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") | |||
| dataset_query = DatasetQuery( | |||
| dataset_id=dataset.id, | |||
| content=query, | |||
| source='hit_testing', | |||
| created_by_role='account', | |||
| created_by=account.id | |||
| dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id | |||
| ) | |||
| db.session.add(dataset_query) | |||
| @@ -72,14 +69,18 @@ class HitTestingService: | |||
| i = 0 | |||
| records = [] | |||
| for document in documents: | |||
| index_node_id = document.metadata['doc_id'] | |||
| segment = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == 'completed', | |||
| DocumentSegment.index_node_id == index_node_id | |||
| ).first() | |||
| index_node_id = document.metadata["doc_id"] | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| i += 1 | |||
| @@ -87,7 +88,7 @@ class HitTestingService: | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata.get('score', None), | |||
| "score": document.metadata.get("score", None), | |||
| } | |||
| records.append(record) | |||
| @@ -98,15 +99,15 @@ class HitTestingService: | |||
| "query": { | |||
| "content": query, | |||
| }, | |||
| "records": records | |||
| "records": records, | |||
| } | |||
| @classmethod | |||
| def hit_testing_args_check(cls, args): | |||
| query = args['query'] | |||
| query = args["query"] | |||
| if not query or len(query) > 250: | |||
| raise ValueError('Query is required and cannot exceed 250 characters') | |||
| raise ValueError("Query is required and cannot exceed 250 characters") | |||
| @staticmethod | |||
| def escape_query_for_search(query: str) -> str: | |||
| @@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService | |||
| class MessageService: | |||
| @classmethod | |||
| def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], | |||
| conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: | |||
| def pagination_by_first_id( | |||
| cls, | |||
| app_model: App, | |||
| user: Optional[Union[Account, EndUser]], | |||
| conversation_id: str, | |||
| first_id: Optional[str], | |||
| limit: int, | |||
| ) -> InfiniteScrollPagination: | |||
| if not user: | |||
| return InfiniteScrollPagination(data=[], limit=limit, has_more=False) | |||
| @@ -36,52 +42,69 @@ class MessageService: | |||
| return InfiniteScrollPagination(data=[], limit=limit, has_more=False) | |||
| conversation = ConversationService.get_conversation( | |||
| app_model=app_model, | |||
| user=user, | |||
| conversation_id=conversation_id | |||
| app_model=app_model, user=user, conversation_id=conversation_id | |||
| ) | |||
| if first_id: | |||
| first_message = db.session.query(Message) \ | |||
| .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() | |||
| first_message = ( | |||
| db.session.query(Message) | |||
| .filter(Message.conversation_id == conversation.id, Message.id == first_id) | |||
| .first() | |||
| ) | |||
| if not first_message: | |||
| raise FirstMessageNotExistsError() | |||
| history_messages = db.session.query(Message).filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < first_message.created_at, | |||
| Message.id != first_message.id | |||
| ) \ | |||
| .order_by(Message.created_at.desc()).limit(limit).all() | |||
| history_messages = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < first_message.created_at, | |||
| Message.id != first_message.id, | |||
| ) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(limit) | |||
| .all() | |||
| ) | |||
| else: | |||
| history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ | |||
| .order_by(Message.created_at.desc()).limit(limit).all() | |||
| history_messages = ( | |||
| db.session.query(Message) | |||
| .filter(Message.conversation_id == conversation.id) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(limit) | |||
| .all() | |||
| ) | |||
| has_more = False | |||
| if len(history_messages) == limit: | |||
| current_page_first_message = history_messages[-1] | |||
| rest_count = db.session.query(Message).filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id | |||
| ).count() | |||
| rest_count = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.conversation_id == conversation.id, | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id, | |||
| ) | |||
| .count() | |||
| ) | |||
| if rest_count > 0: | |||
| has_more = True | |||
| history_messages = list(reversed(history_messages)) | |||
| return InfiniteScrollPagination( | |||
| data=history_messages, | |||
| limit=limit, | |||
| has_more=has_more | |||
| ) | |||
| return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) | |||
| @classmethod | |||
| def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, | |||
| include_ids: Optional[list] = None) -> InfiniteScrollPagination: | |||
| def pagination_by_last_id( | |||
| cls, | |||
| app_model: App, | |||
| user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], | |||
| limit: int, | |||
| conversation_id: Optional[str] = None, | |||
| include_ids: Optional[list] = None, | |||
| ) -> InfiniteScrollPagination: | |||
| if not user: | |||
| return InfiniteScrollPagination(data=[], limit=limit, has_more=False) | |||
| @@ -89,9 +112,7 @@ class MessageService: | |||
| if conversation_id is not None: | |||
| conversation = ConversationService.get_conversation( | |||
| app_model=app_model, | |||
| user=user, | |||
| conversation_id=conversation_id | |||
| app_model=app_model, user=user, conversation_id=conversation_id | |||
| ) | |||
| base_query = base_query.filter(Message.conversation_id == conversation.id) | |||
| @@ -105,10 +126,12 @@ class MessageService: | |||
| if not last_message: | |||
| raise LastMessageNotExistsError() | |||
| history_messages = base_query.filter( | |||
| Message.created_at < last_message.created_at, | |||
| Message.id != last_message.id | |||
| ).order_by(Message.created_at.desc()).limit(limit).all() | |||
| history_messages = ( | |||
| base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) | |||
| .order_by(Message.created_at.desc()) | |||
| .limit(limit) | |||
| .all() | |||
| ) | |||
| else: | |||
| history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all() | |||
| @@ -116,30 +139,22 @@ class MessageService: | |||
| if len(history_messages) == limit: | |||
| current_page_first_message = history_messages[-1] | |||
| rest_count = base_query.filter( | |||
| Message.created_at < current_page_first_message.created_at, | |||
| Message.id != current_page_first_message.id | |||
| Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id | |||
| ).count() | |||
| if rest_count > 0: | |||
| has_more = True | |||
| return InfiniteScrollPagination( | |||
| data=history_messages, | |||
| limit=limit, | |||
| has_more=has_more | |||
| ) | |||
| return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) | |||
| @classmethod | |||
| def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], | |||
| rating: Optional[str]) -> MessageFeedback: | |||
| def create_feedback( | |||
| cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] | |||
| ) -> MessageFeedback: | |||
| if not user: | |||
| raise ValueError('user cannot be None') | |||
| raise ValueError("user cannot be None") | |||
| message = cls.get_message( | |||
| app_model=app_model, | |||
| user=user, | |||
| message_id=message_id | |||
| ) | |||
| message = cls.get_message(app_model=app_model, user=user, message_id=message_id) | |||
| feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback | |||
| @@ -148,14 +163,14 @@ class MessageService: | |||
| elif rating and feedback: | |||
| feedback.rating = rating | |||
| elif not rating and not feedback: | |||
| raise ValueError('rating cannot be None when feedback not exists') | |||
| raise ValueError("rating cannot be None when feedback not exists") | |||
| else: | |||
| feedback = MessageFeedback( | |||
| app_id=app_model.id, | |||
| conversation_id=message.conversation_id, | |||
| message_id=message.id, | |||
| rating=rating, | |||
| from_source=('user' if isinstance(user, EndUser) else 'admin'), | |||
| from_source=("user" if isinstance(user, EndUser) else "admin"), | |||
| from_end_user_id=(user.id if isinstance(user, EndUser) else None), | |||
| from_account_id=(user.id if isinstance(user, Account) else None), | |||
| ) | |||
| @@ -167,13 +182,17 @@ class MessageService: | |||
| @classmethod | |||
| def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id, | |||
| Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), | |||
| Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Message.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| ).first() | |||
| message = ( | |||
| db.session.query(Message) | |||
| .filter( | |||
| Message.id == message_id, | |||
| Message.app_id == app_model.id, | |||
| Message.from_source == ("api" if isinstance(user, EndUser) else "console"), | |||
| Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), | |||
| Message.from_account_id == (user.id if isinstance(user, Account) else None), | |||
| ) | |||
| .first() | |||
| ) | |||
| if not message: | |||
| raise MessageNotExistsError() | |||
| @@ -181,27 +200,22 @@ class MessageService: | |||
| return message | |||
| @classmethod | |||
| def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], | |||
| message_id: str, invoke_from: InvokeFrom) -> list[Message]: | |||
| def get_suggested_questions_after_answer( | |||
| cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom | |||
| ) -> list[Message]: | |||
| if not user: | |||
| raise ValueError('user cannot be None') | |||
| raise ValueError("user cannot be None") | |||
| message = cls.get_message( | |||
| app_model=app_model, | |||
| user=user, | |||
| message_id=message_id | |||
| ) | |||
| message = cls.get_message(app_model=app_model, user=user, message_id=message_id) | |||
| conversation = ConversationService.get_conversation( | |||
| app_model=app_model, | |||
| conversation_id=message.conversation_id, | |||
| user=user | |||
| app_model=app_model, conversation_id=message.conversation_id, user=user | |||
| ) | |||
| if not conversation: | |||
| raise ConversationNotExistsError() | |||
| if conversation.status != 'normal': | |||
| if conversation.status != "normal": | |||
| raise ConversationCompletedError() | |||
| model_manager = ModelManager() | |||
| @@ -216,24 +230,23 @@ class MessageService: | |||
| if workflow is None: | |||
| return [] | |||
| app_config = AdvancedChatAppConfigManager.get_app_config( | |||
| app_model=app_model, | |||
| workflow=workflow | |||
| ) | |||
| app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) | |||
| if not app_config.additional_features.suggested_questions_after_answer: | |||
| raise SuggestedQuestionsAfterAnswerDisabledError() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=app_model.tenant_id, | |||
| model_type=ModelType.LLM | |||
| tenant_id=app_model.tenant_id, model_type=ModelType.LLM | |||
| ) | |||
| else: | |||
| if not conversation.override_model_configs: | |||
| app_model_config = db.session.query(AppModelConfig).filter( | |||
| AppModelConfig.id == conversation.app_model_config_id, | |||
| AppModelConfig.app_id == app_model.id | |||
| ).first() | |||
| app_model_config = ( | |||
| db.session.query(AppModelConfig) | |||
| .filter( | |||
| AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id | |||
| ) | |||
| .first() | |||
| ) | |||
| else: | |||
| conversation_override_model_configs = json.loads(conversation.override_model_configs) | |||
| app_model_config = AppModelConfig( | |||
| @@ -249,16 +262,13 @@ class MessageService: | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=app_model.tenant_id, | |||
| provider=app_model_config.model_dict['provider'], | |||
| provider=app_model_config.model_dict["provider"], | |||
| model_type=ModelType.LLM, | |||
| model=app_model_config.model_dict['name'] | |||
| model=app_model_config.model_dict["name"], | |||
| ) | |||
| # get memory of conversation (read-only) | |||
| memory = TokenBufferMemory( | |||
| conversation=conversation, | |||
| model_instance=model_instance | |||
| ) | |||
| memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) | |||
| histories = memory.get_history_prompt_text( | |||
| max_token_limit=3000, | |||
| @@ -267,18 +277,14 @@ class MessageService: | |||
| with measure_time() as timer: | |||
| questions = LLMGenerator.generate_suggested_questions_after_answer( | |||
| tenant_id=app_model.tenant_id, | |||
| histories=histories | |||
| tenant_id=app_model.tenant_id, histories=histories | |||
| ) | |||
| # get tracing instance | |||
| trace_manager = TraceQueueManager(app_id=app_model.id) | |||
| trace_manager.add_trace_task( | |||
| TraceTask( | |||
| TraceTaskName.SUGGESTED_QUESTION_TRACE, | |||
| message_id=message_id, | |||
| suggested_question=questions, | |||
| timer=timer | |||
| TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer | |||
| ) | |||
| ) | |||
| @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) | |||
| class ModelLoadBalancingService: | |||
| def __init__(self) -> None: | |||
| self.provider_manager = ProviderManager() | |||
| @@ -46,10 +45,7 @@ class ModelLoadBalancingService: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Enable model load balancing | |||
| provider_configuration.enable_model_load_balancing( | |||
| model=model, | |||
| model_type=ModelType.value_of(model_type) | |||
| ) | |||
| provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) | |||
| def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: | |||
| """ | |||
| @@ -70,13 +66,11 @@ class ModelLoadBalancingService: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # disable model load balancing | |||
| provider_configuration.disable_model_load_balancing( | |||
| model=model, | |||
| model_type=ModelType.value_of(model_type) | |||
| ) | |||
| provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) | |||
| def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ | |||
| -> tuple[bool, list[dict]]: | |||
| def get_load_balancing_configs( | |||
| self, tenant_id: str, provider: str, model: str, model_type: str | |||
| ) -> tuple[bool, list[dict]]: | |||
| """ | |||
| Get load balancing configurations. | |||
| :param tenant_id: workspace id | |||
| @@ -107,20 +101,24 @@ class ModelLoadBalancingService: | |||
| is_load_balancing_enabled = True | |||
| # Get load balancing configurations | |||
| load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ | |||
| load_balancing_configs = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model | |||
| ).order_by(LoadBalancingModelConfig.created_at).all() | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| ) | |||
| .order_by(LoadBalancingModelConfig.created_at) | |||
| .all() | |||
| ) | |||
| if provider_configuration.custom_configuration.provider: | |||
| # check if the inherit configuration exists, | |||
| # inherit is represented for the provider or model custom credentials | |||
| inherit_config_exists = False | |||
| for load_balancing_config in load_balancing_configs: | |||
| if load_balancing_config.name == '__inherit__': | |||
| if load_balancing_config.name == "__inherit__": | |||
| inherit_config_exists = True | |||
| break | |||
| @@ -133,7 +131,7 @@ class ModelLoadBalancingService: | |||
| else: | |||
| # move the inherit configuration to the first | |||
| for i, load_balancing_config in enumerate(load_balancing_configs[:]): | |||
| if load_balancing_config.name == '__inherit__': | |||
| if load_balancing_config.name == "__inherit__": | |||
| inherit_config = load_balancing_configs.pop(i) | |||
| load_balancing_configs.insert(0, inherit_config) | |||
| @@ -151,7 +149,7 @@ class ModelLoadBalancingService: | |||
| provider=provider, | |||
| model=model, | |||
| model_type=model_type, | |||
| config_id=load_balancing_config.id | |||
| config_id=load_balancing_config.id, | |||
| ) | |||
| try: | |||
| @@ -172,32 +170,32 @@ class ModelLoadBalancingService: | |||
| if variable in credentials: | |||
| try: | |||
| credentials[variable] = encrypter.decrypt_token_with_decoding( | |||
| credentials.get(variable), | |||
| decoding_rsa_key, | |||
| decoding_cipher_rsa | |||
| credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa | |||
| ) | |||
| except ValueError: | |||
| pass | |||
| # Obfuscate credentials | |||
| credentials = provider_configuration.obfuscated_credentials( | |||
| credentials=credentials, | |||
| credential_form_schemas=credential_schemas.credential_form_schemas | |||
| credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas | |||
| ) | |||
| datas.append({ | |||
| 'id': load_balancing_config.id, | |||
| 'name': load_balancing_config.name, | |||
| 'credentials': credentials, | |||
| 'enabled': load_balancing_config.enabled, | |||
| 'in_cooldown': in_cooldown, | |||
| 'ttl': ttl | |||
| }) | |||
| datas.append( | |||
| { | |||
| "id": load_balancing_config.id, | |||
| "name": load_balancing_config.name, | |||
| "credentials": credentials, | |||
| "enabled": load_balancing_config.enabled, | |||
| "in_cooldown": in_cooldown, | |||
| "ttl": ttl, | |||
| } | |||
| ) | |||
| return is_load_balancing_enabled, datas | |||
| def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ | |||
| -> Optional[dict]: | |||
| def get_load_balancing_config( | |||
| self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str | |||
| ) -> Optional[dict]: | |||
| """ | |||
| Get load balancing configuration. | |||
| :param tenant_id: workspace id | |||
| @@ -219,14 +217,17 @@ class ModelLoadBalancingService: | |||
| model_type = ModelType.value_of(model_type) | |||
| # Get load balancing configurations | |||
| load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ | |||
| load_balancing_model_config = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| LoadBalancingModelConfig.id == config_id | |||
| ).first() | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| LoadBalancingModelConfig.id == config_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not load_balancing_model_config: | |||
| return None | |||
| @@ -244,19 +245,19 @@ class ModelLoadBalancingService: | |||
| # Obfuscate credentials | |||
| credentials = provider_configuration.obfuscated_credentials( | |||
| credentials=credentials, | |||
| credential_form_schemas=credential_schemas.credential_form_schemas | |||
| credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas | |||
| ) | |||
| return { | |||
| 'id': load_balancing_model_config.id, | |||
| 'name': load_balancing_model_config.name, | |||
| 'credentials': credentials, | |||
| 'enabled': load_balancing_model_config.enabled | |||
| "id": load_balancing_model_config.id, | |||
| "name": load_balancing_model_config.name, | |||
| "credentials": credentials, | |||
| "enabled": load_balancing_model_config.enabled, | |||
| } | |||
| def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ | |||
| -> LoadBalancingModelConfig: | |||
| def _init_inherit_config( | |||
| self, tenant_id: str, provider: str, model: str, model_type: ModelType | |||
| ) -> LoadBalancingModelConfig: | |||
| """ | |||
| Initialize the inherit configuration. | |||
| :param tenant_id: workspace id | |||
| @@ -271,18 +272,16 @@ class ModelLoadBalancingService: | |||
| provider_name=provider, | |||
| model_type=model_type.to_origin_model_type(), | |||
| model_name=model, | |||
| name='__inherit__' | |||
| name="__inherit__", | |||
| ) | |||
| db.session.add(inherit_config) | |||
| db.session.commit() | |||
| return inherit_config | |||
| def update_load_balancing_configs(self, tenant_id: str, | |||
| provider: str, | |||
| model: str, | |||
| model_type: str, | |||
| configs: list[dict]) -> None: | |||
| def update_load_balancing_configs( | |||
| self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] | |||
| ) -> None: | |||
| """ | |||
| Update load balancing configurations. | |||
| :param tenant_id: workspace id | |||
| @@ -304,15 +303,18 @@ class ModelLoadBalancingService: | |||
| model_type = ModelType.value_of(model_type) | |||
| if not isinstance(configs, list): | |||
| raise ValueError('Invalid load balancing configs') | |||
| raise ValueError("Invalid load balancing configs") | |||
| current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ | |||
| current_load_balancing_configs = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model | |||
| ).all() | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| ) | |||
| .all() | |||
| ) | |||
| # id as key, config as value | |||
| current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} | |||
| @@ -320,25 +322,25 @@ class ModelLoadBalancingService: | |||
| for config in configs: | |||
| if not isinstance(config, dict): | |||
| raise ValueError('Invalid load balancing config') | |||
| raise ValueError("Invalid load balancing config") | |||
| config_id = config.get('id') | |||
| name = config.get('name') | |||
| credentials = config.get('credentials') | |||
| enabled = config.get('enabled') | |||
| config_id = config.get("id") | |||
| name = config.get("name") | |||
| credentials = config.get("credentials") | |||
| enabled = config.get("enabled") | |||
| if not name: | |||
| raise ValueError('Invalid load balancing config name') | |||
| raise ValueError("Invalid load balancing config name") | |||
| if enabled is None: | |||
| raise ValueError('Invalid load balancing config enabled') | |||
| raise ValueError("Invalid load balancing config enabled") | |||
| # is config exists | |||
| if config_id: | |||
| config_id = str(config_id) | |||
| if config_id not in current_load_balancing_configs_dict: | |||
| raise ValueError('Invalid load balancing config id: {}'.format(config_id)) | |||
| raise ValueError("Invalid load balancing config id: {}".format(config_id)) | |||
| updated_config_ids.add(config_id) | |||
| @@ -347,11 +349,11 @@ class ModelLoadBalancingService: | |||
| # check duplicate name | |||
| for current_load_balancing_config in current_load_balancing_configs: | |||
| if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: | |||
| raise ValueError('Load balancing config name {} already exists'.format(name)) | |||
| raise ValueError("Load balancing config name {} already exists".format(name)) | |||
| if credentials: | |||
| if not isinstance(credentials, dict): | |||
| raise ValueError('Invalid load balancing config credentials') | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| # validate custom provider config | |||
| credentials = self._custom_credentials_validate( | |||
| @@ -361,7 +363,7 @@ class ModelLoadBalancingService: | |||
| model=model, | |||
| credentials=credentials, | |||
| load_balancing_model_config=load_balancing_config, | |||
| validate=False | |||
| validate=False, | |||
| ) | |||
| # update load balancing config | |||
| @@ -375,19 +377,19 @@ class ModelLoadBalancingService: | |||
| self._clear_credentials_cache(tenant_id, config_id) | |||
| else: | |||
| # create load balancing config | |||
| if name == '__inherit__': | |||
| raise ValueError('Invalid load balancing config name') | |||
| if name == "__inherit__": | |||
| raise ValueError("Invalid load balancing config name") | |||
| # check duplicate name | |||
| for current_load_balancing_config in current_load_balancing_configs: | |||
| if current_load_balancing_config.name == name: | |||
| raise ValueError('Load balancing config name {} already exists'.format(name)) | |||
| raise ValueError("Load balancing config name {} already exists".format(name)) | |||
| if not credentials: | |||
| raise ValueError('Invalid load balancing config credentials') | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| if not isinstance(credentials, dict): | |||
| raise ValueError('Invalid load balancing config credentials') | |||
| raise ValueError("Invalid load balancing config credentials") | |||
| # validate custom provider config | |||
| credentials = self._custom_credentials_validate( | |||
| @@ -396,7 +398,7 @@ class ModelLoadBalancingService: | |||
| model_type=model_type, | |||
| model=model, | |||
| credentials=credentials, | |||
| validate=False | |||
| validate=False, | |||
| ) | |||
| # create load balancing config | |||
| @@ -406,7 +408,7 @@ class ModelLoadBalancingService: | |||
| model_type=model_type.to_origin_model_type(), | |||
| model_name=model, | |||
| name=name, | |||
| encrypted_config=json.dumps(credentials) | |||
| encrypted_config=json.dumps(credentials), | |||
| ) | |||
| db.session.add(load_balancing_model_config) | |||
| @@ -420,12 +422,15 @@ class ModelLoadBalancingService: | |||
| self._clear_credentials_cache(tenant_id, config_id) | |||
| def validate_load_balancing_credentials(self, tenant_id: str, | |||
| provider: str, | |||
| model: str, | |||
| model_type: str, | |||
| credentials: dict, | |||
| config_id: Optional[str] = None) -> None: | |||
| def validate_load_balancing_credentials( | |||
| self, | |||
| tenant_id: str, | |||
| provider: str, | |||
| model: str, | |||
| model_type: str, | |||
| credentials: dict, | |||
| config_id: Optional[str] = None, | |||
| ) -> None: | |||
| """ | |||
| Validate load balancing credentials. | |||
| :param tenant_id: workspace id | |||
| @@ -450,14 +455,17 @@ class ModelLoadBalancingService: | |||
| load_balancing_model_config = None | |||
| if config_id: | |||
| # Get load balancing config | |||
| load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ | |||
| load_balancing_model_config = ( | |||
| db.session.query(LoadBalancingModelConfig) | |||
| .filter( | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| LoadBalancingModelConfig.id == config_id | |||
| ).first() | |||
| LoadBalancingModelConfig.tenant_id == tenant_id, | |||
| LoadBalancingModelConfig.provider_name == provider, | |||
| LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | |||
| LoadBalancingModelConfig.model_name == model, | |||
| LoadBalancingModelConfig.id == config_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not load_balancing_model_config: | |||
| raise ValueError(f"Load balancing config {config_id} does not exist.") | |||
| @@ -469,16 +477,19 @@ class ModelLoadBalancingService: | |||
| model_type=model_type, | |||
| model=model, | |||
| credentials=credentials, | |||
| load_balancing_model_config=load_balancing_model_config | |||
| load_balancing_model_config=load_balancing_model_config, | |||
| ) | |||
| def _custom_credentials_validate(self, tenant_id: str, | |||
| provider_configuration: ProviderConfiguration, | |||
| model_type: ModelType, | |||
| model: str, | |||
| credentials: dict, | |||
| load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, | |||
| validate: bool = True) -> dict: | |||
| def _custom_credentials_validate( | |||
| self, | |||
| tenant_id: str, | |||
| provider_configuration: ProviderConfiguration, | |||
| model_type: ModelType, | |||
| model: str, | |||
| credentials: dict, | |||
| load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, | |||
| validate: bool = True, | |||
| ) -> dict: | |||
| """ | |||
| Validate custom credentials. | |||
| :param tenant_id: workspace id | |||
| @@ -521,12 +532,11 @@ class ModelLoadBalancingService: | |||
| provider=provider_configuration.provider.provider, | |||
| model_type=model_type, | |||
| model=model, | |||
| credentials=credentials | |||
| credentials=credentials, | |||
| ) | |||
| else: | |||
| credentials = model_provider_factory.provider_credentials_validate( | |||
| provider=provider_configuration.provider.provider, | |||
| credentials=credentials | |||
| provider=provider_configuration.provider.provider, credentials=credentials | |||
| ) | |||
| for key, value in credentials.items(): | |||
| @@ -535,8 +545,9 @@ class ModelLoadBalancingService: | |||
| return credentials | |||
| def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ | |||
| -> ModelCredentialSchema | ProviderCredentialSchema: | |||
| def _get_credential_schema( | |||
| self, provider_configuration: ProviderConfiguration | |||
| ) -> ModelCredentialSchema | ProviderCredentialSchema: | |||
| """ | |||
| Get form schemas. | |||
| :param provider_configuration: provider configuration | |||
| @@ -558,9 +569,7 @@ class ModelLoadBalancingService: | |||
| :return: | |||
| """ | |||
| provider_model_credentials_cache = ProviderCredentialsCache( | |||
| tenant_id=tenant_id, | |||
| identity_id=config_id, | |||
| cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL | |||
| tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL | |||
| ) | |||
| provider_model_credentials_cache.delete() | |||
| @@ -73,8 +73,8 @@ class ModelProviderService: | |||
| system_configuration=SystemConfigurationResponse( | |||
| enabled=provider_configuration.system_configuration.enabled, | |||
| current_quota_type=provider_configuration.system_configuration.current_quota_type, | |||
| quota_configurations=provider_configuration.system_configuration.quota_configurations | |||
| ) | |||
| quota_configurations=provider_configuration.system_configuration.quota_configurations, | |||
| ), | |||
| ) | |||
| provider_responses.append(provider_response) | |||
| @@ -95,9 +95,9 @@ class ModelProviderService: | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| # Get provider available models | |||
| return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models( | |||
| provider=provider | |||
| )] | |||
| return [ | |||
| ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) | |||
| ] | |||
| def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: | |||
| """ | |||
| @@ -195,13 +195,12 @@ class ModelProviderService: | |||
| # Get model custom credentials from ProviderModel if exists | |||
| return provider_configuration.get_custom_model_credentials( | |||
| model_type=ModelType.value_of(model_type), | |||
| model=model, | |||
| obfuscated=True | |||
| model_type=ModelType.value_of(model_type), model=model, obfuscated=True | |||
| ) | |||
| def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str, | |||
| credentials: dict) -> None: | |||
| def model_credentials_validate( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||
| ) -> None: | |||
| """ | |||
| validate model credentials. | |||
| @@ -222,13 +221,12 @@ class ModelProviderService: | |||
| # Validate model credentials | |||
| provider_configuration.custom_model_credentials_validate( | |||
| model_type=ModelType.value_of(model_type), | |||
| model=model, | |||
| credentials=credentials | |||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||
| ) | |||
| def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, | |||
| credentials: dict) -> None: | |||
| def save_model_credentials( | |||
| self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict | |||
| ) -> None: | |||
| """ | |||
| save model credentials. | |||
| @@ -249,9 +247,7 @@ class ModelProviderService: | |||
| # Add or update custom model credentials | |||
| provider_configuration.add_or_update_custom_model_credentials( | |||
| model_type=ModelType.value_of(model_type), | |||
| model=model, | |||
| credentials=credentials | |||
| model_type=ModelType.value_of(model_type), model=model, credentials=credentials | |||
| ) | |||
| def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: | |||
| @@ -273,10 +269,7 @@ class ModelProviderService: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Remove custom model credentials | |||
| provider_configuration.delete_custom_model_credentials( | |||
| model_type=ModelType.value_of(model_type), | |||
| model=model | |||
| ) | |||
| provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) | |||
| def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: | |||
| """ | |||
| @@ -290,9 +283,7 @@ class ModelProviderService: | |||
| provider_configurations = self.provider_manager.get_configurations(tenant_id) | |||
| # Get provider available models | |||
| models = provider_configurations.get_models( | |||
| model_type=ModelType.value_of(model_type) | |||
| ) | |||
| models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) | |||
| # Group models by provider | |||
| provider_models = {} | |||
| @@ -323,16 +314,19 @@ class ModelProviderService: | |||
| icon_small=first_model.provider.icon_small, | |||
| icon_large=first_model.provider.icon_large, | |||
| status=CustomConfigurationStatus.ACTIVE, | |||
| models=[ProviderModelWithStatusEntity( | |||
| model=model.model, | |||
| label=model.label, | |||
| model_type=model.model_type, | |||
| features=model.features, | |||
| fetch_from=model.fetch_from, | |||
| model_properties=model.model_properties, | |||
| status=model.status, | |||
| load_balancing_enabled=model.load_balancing_enabled | |||
| ) for model in models] | |||
| models=[ | |||
| ProviderModelWithStatusEntity( | |||
| model=model.model, | |||
| label=model.label, | |||
| model_type=model.model_type, | |||
| features=model.features, | |||
| fetch_from=model.fetch_from, | |||
| model_properties=model.model_properties, | |||
| status=model.status, | |||
| load_balancing_enabled=model.load_balancing_enabled, | |||
| ) | |||
| for model in models | |||
| ], | |||
| ) | |||
| ) | |||
| @@ -361,19 +355,13 @@ class ModelProviderService: | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # fetch credentials | |||
| credentials = provider_configuration.get_current_credentials( | |||
| model_type=ModelType.LLM, | |||
| model=model | |||
| ) | |||
| credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) | |||
| if not credentials: | |||
| return [] | |||
| # Call get_parameter_rules method of model instance to get model parameter rules | |||
| return model_type_instance.get_parameter_rules( | |||
| model=model, | |||
| credentials=credentials | |||
| ) | |||
| return model_type_instance.get_parameter_rules(model=model, credentials=credentials) | |||
| def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: | |||
| """ | |||
| @@ -384,22 +372,23 @@ class ModelProviderService: | |||
| :return: | |||
| """ | |||
| model_type_enum = ModelType.value_of(model_type) | |||
| result = self.provider_manager.get_default_model( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type_enum | |||
| ) | |||
| result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) | |||
| try: | |||
| return DefaultModelResponse( | |||
| model=result.model, | |||
| model_type=result.model_type, | |||
| provider=SimpleProviderEntityResponse( | |||
| provider=result.provider.provider, | |||
| label=result.provider.label, | |||
| icon_small=result.provider.icon_small, | |||
| icon_large=result.provider.icon_large, | |||
| supported_model_types=result.provider.supported_model_types | |||
| return ( | |||
| DefaultModelResponse( | |||
| model=result.model, | |||
| model_type=result.model_type, | |||
| provider=SimpleProviderEntityResponse( | |||
| provider=result.provider.provider, | |||
| label=result.provider.label, | |||
| icon_small=result.provider.icon_small, | |||
| icon_large=result.provider.icon_large, | |||
| supported_model_types=result.provider.supported_model_types, | |||
| ), | |||
| ) | |||
| ) if result else None | |||
| if result | |||
| else None | |||
| ) | |||
| except Exception as e: | |||
| logger.info(f"get_default_model_of_model_type error: {e}") | |||
| return None | |||
| @@ -416,13 +405,12 @@ class ModelProviderService: | |||
| """ | |||
| model_type_enum = ModelType.value_of(model_type) | |||
| self.provider_manager.update_default_model_record( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type_enum, | |||
| provider=provider, | |||
| model=model | |||
| tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model | |||
| ) | |||
| def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]: | |||
| def get_model_provider_icon( | |||
| self, provider: str, icon_type: str, lang: str | |||
| ) -> tuple[Optional[bytes], Optional[str]]: | |||
| """ | |||
| get model provider icon. | |||
| @@ -434,11 +422,11 @@ class ModelProviderService: | |||
| provider_instance = model_provider_factory.get_provider_instance(provider) | |||
| provider_schema = provider_instance.get_provider_schema() | |||
| if icon_type.lower() == 'icon_small': | |||
| if icon_type.lower() == "icon_small": | |||
| if not provider_schema.icon_small: | |||
| raise ValueError(f"Provider {provider} does not have small icon.") | |||
| if lang.lower() == 'zh_hans': | |||
| if lang.lower() == "zh_hans": | |||
| file_name = provider_schema.icon_small.zh_Hans | |||
| else: | |||
| file_name = provider_schema.icon_small.en_US | |||
| @@ -446,13 +434,15 @@ class ModelProviderService: | |||
| if not provider_schema.icon_large: | |||
| raise ValueError(f"Provider {provider} does not have large icon.") | |||
| if lang.lower() == 'zh_hans': | |||
| if lang.lower() == "zh_hans": | |||
| file_name = provider_schema.icon_large.zh_Hans | |||
| else: | |||
| file_name = provider_schema.icon_large.en_US | |||
| root_path = current_app.root_path | |||
| provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/'))) | |||
| provider_instance_path = os.path.dirname( | |||
| os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/")) | |||
| ) | |||
| file_path = os.path.join(provider_instance_path, "_assets") | |||
| file_path = os.path.join(file_path, file_name) | |||
| @@ -460,10 +450,10 @@ class ModelProviderService: | |||
| return None, None | |||
| mimetype, _ = mimetypes.guess_type(file_path) | |||
| mimetype = mimetype or 'application/octet-stream' | |||
| mimetype = mimetype or "application/octet-stream" | |||
| # read binary from file | |||
| with open(file_path, 'rb') as f: | |||
| with open(file_path, "rb") as f: | |||
| byte_data = f.read() | |||
| return byte_data, mimetype | |||
| @@ -509,10 +499,7 @@ class ModelProviderService: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Enable model | |||
| provider_configuration.enable_model( | |||
| model=model, | |||
| model_type=ModelType.value_of(model_type) | |||
| ) | |||
| provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) | |||
| def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: | |||
| """ | |||
| @@ -533,78 +520,49 @@ class ModelProviderService: | |||
| raise ValueError(f"Provider {provider} does not exist.") | |||
| # Enable model | |||
| provider_configuration.disable_model( | |||
| model=model, | |||
| model_type=ModelType.value_of(model_type) | |||
| ) | |||
| provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) | |||
| def free_quota_submit(self, tenant_id: str, provider: str): | |||
| api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") | |||
| api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") | |||
| api_url = api_base_url + '/api/v1/providers/apply' | |||
| api_url = api_base_url + "/api/v1/providers/apply" | |||
| headers = { | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f"Bearer {api_key}" | |||
| } | |||
| response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider}) | |||
| headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} | |||
| response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) | |||
| if not response.ok: | |||
| logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") | |||
| raise ValueError(f"Error: {response.status_code} ") | |||
| if response.json()["code"] != 'success': | |||
| raise ValueError( | |||
| f"error: {response.json()['message']}" | |||
| ) | |||
| if response.json()["code"] != "success": | |||
| raise ValueError(f"error: {response.json()['message']}") | |||
| rst = response.json() | |||
| if rst['type'] == 'redirect': | |||
| return { | |||
| 'type': rst['type'], | |||
| 'redirect_url': rst['redirect_url'] | |||
| } | |||
| if rst["type"] == "redirect": | |||
| return {"type": rst["type"], "redirect_url": rst["redirect_url"]} | |||
| else: | |||
| return { | |||
| 'type': rst['type'], | |||
| 'result': 'success' | |||
| } | |||
| return {"type": rst["type"], "result": "success"} | |||
| def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): | |||
| api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") | |||
| api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") | |||
| api_url = api_base_url + '/api/v1/providers/qualification-verify' | |||
| api_url = api_base_url + "/api/v1/providers/qualification-verify" | |||
| headers = { | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f"Bearer {api_key}" | |||
| } | |||
| json_data = {'workspace_id': tenant_id, 'provider_name': provider} | |||
| headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} | |||
| json_data = {"workspace_id": tenant_id, "provider_name": provider} | |||
| if token: | |||
| json_data['token'] = token | |||
| response = requests.post(api_url, headers=headers, | |||
| json=json_data) | |||
| json_data["token"] = token | |||
| response = requests.post(api_url, headers=headers, json=json_data) | |||
| if not response.ok: | |||
| logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") | |||
| raise ValueError(f"Error: {response.status_code} ") | |||
| rst = response.json() | |||
| if rst["code"] != 'success': | |||
| raise ValueError( | |||
| f"error: {rst['message']}" | |||
| ) | |||
| if rst["code"] != "success": | |||
| raise ValueError(f"error: {rst['message']}") | |||
| data = rst['data'] | |||
| if data['qualified'] is True: | |||
| return { | |||
| 'result': 'success', | |||
| 'provider_name': provider, | |||
| 'flag': True | |||
| } | |||
| data = rst["data"] | |||
| if data["qualified"] is True: | |||
| return {"result": "success", "provider_name": provider, "flag": True} | |||
| else: | |||
| return { | |||
| 'result': 'success', | |||
| 'provider_name': provider, | |||
| 'flag': False, | |||
| 'reason': data['reason'] | |||
| } | |||
| return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]} | |||
| @@ -4,17 +4,18 @@ from models.model import App, AppModelConfig | |||
| class ModerationService: | |||
| def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: | |||
| app_model_config: AppModelConfig = None | |||
| app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() | |||
| app_model_config = ( | |||
| db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() | |||
| ) | |||
| if not app_model_config: | |||
| raise ValueError("app model config not found") | |||
| name = app_model_config.sensitive_word_avoidance_dict['type'] | |||
| config = app_model_config.sensitive_word_avoidance_dict['config'] | |||
| name = app_model_config.sensitive_word_avoidance_dict["type"] | |||
| config = app_model_config.sensitive_word_avoidance_dict["config"] | |||
| moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) | |||
| return moderation.moderation_for_outputs(text) | |||
| @@ -4,15 +4,12 @@ import requests | |||
| class OperationService: | |||
| base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') | |||
| secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') | |||
| base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") | |||
| secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") | |||
| @classmethod | |||
| def _send_request(cls, method, endpoint, json=None, params=None): | |||
| headers = { | |||
| "Content-Type": "application/json", | |||
| "Billing-Api-Secret-Key": cls.secret_key | |||
| } | |||
| headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} | |||
| url = f"{cls.base_url}{endpoint}" | |||
| response = requests.request(method, url, json=json, params=params, headers=headers) | |||
| @@ -22,11 +19,11 @@ class OperationService: | |||
| @classmethod | |||
| def record_utm(cls, tenant_id: str, utm_info: dict): | |||
| params = { | |||
| 'tenant_id': tenant_id, | |||
| 'utm_source': utm_info.get('utm_source', ''), | |||
| 'utm_medium': utm_info.get('utm_medium', ''), | |||
| 'utm_campaign': utm_info.get('utm_campaign', ''), | |||
| 'utm_content': utm_info.get('utm_content', ''), | |||
| 'utm_term': utm_info.get('utm_term', '') | |||
| "tenant_id": tenant_id, | |||
| "utm_source": utm_info.get("utm_source", ""), | |||
| "utm_medium": utm_info.get("utm_medium", ""), | |||
| "utm_campaign": utm_info.get("utm_campaign", ""), | |||
| "utm_content": utm_info.get("utm_content", ""), | |||
| "utm_term": utm_info.get("utm_term", ""), | |||
| } | |||
| return cls._send_request('POST', '/tenant_utms', params=params) | |||
| return cls._send_request("POST", "/tenant_utms", params=params) | |||
| @@ -12,19 +12,25 @@ class OpsService: | |||
| :param tracing_provider: tracing provider | |||
| :return: | |||
| """ | |||
| trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( | |||
| TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider | |||
| ).first() | |||
| trace_config_data: TraceAppConfig = ( | |||
| db.session.query(TraceAppConfig) | |||
| .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) | |||
| .first() | |||
| ) | |||
| if not trace_config_data: | |||
| return None | |||
| # decrypt_token and obfuscated_token | |||
| tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id | |||
| decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) | |||
| if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')): | |||
| decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( | |||
| tenant_id, tracing_provider, trace_config_data.tracing_config | |||
| ) | |||
| if tracing_provider == "langfuse" and ( | |||
| "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") | |||
| ): | |||
| project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) | |||
| decrypt_tracing_config['project_key'] = project_key | |||
| decrypt_tracing_config["project_key"] = project_key | |||
| decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) | |||
| @@ -44,8 +50,10 @@ class OpsService: | |||
| if tracing_provider not in provider_config_map.keys() and tracing_provider: | |||
| return {"error": f"Invalid tracing provider: {tracing_provider}"} | |||
| config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ | |||
| provider_config_map[tracing_provider]['other_keys'] | |||
| config_class, other_keys = ( | |||
| provider_config_map[tracing_provider]["config_class"], | |||
| provider_config_map[tracing_provider]["other_keys"], | |||
| ) | |||
| default_config_instance = config_class(**tracing_config) | |||
| for key in other_keys: | |||
| if key in tracing_config and tracing_config[key] == "": | |||
| @@ -59,9 +67,11 @@ class OpsService: | |||
| project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) | |||
| # check if trace config already exists | |||
| trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( | |||
| TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider | |||
| ).first() | |||
| trace_config_data: TraceAppConfig = ( | |||
| db.session.query(TraceAppConfig) | |||
| .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) | |||
| .first() | |||
| ) | |||
| if trace_config_data: | |||
| return None | |||
| @@ -69,8 +79,8 @@ class OpsService: | |||
| # get tenant id | |||
| tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id | |||
| tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) | |||
| if tracing_provider == 'langfuse': | |||
| tracing_config['project_key'] = project_key | |||
| if tracing_provider == "langfuse": | |||
| tracing_config["project_key"] = project_key | |||
| trace_config_data = TraceAppConfig( | |||
| app_id=app_id, | |||
| tracing_provider=tracing_provider, | |||
| @@ -94,9 +104,11 @@ class OpsService: | |||
| raise ValueError(f"Invalid tracing provider: {tracing_provider}") | |||
| # check if trace config already exists | |||
| current_trace_config = db.session.query(TraceAppConfig).filter( | |||
| TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider | |||
| ).first() | |||
| current_trace_config = ( | |||
| db.session.query(TraceAppConfig) | |||
| .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) | |||
| .first() | |||
| ) | |||
| if not current_trace_config: | |||
| return None | |||
| @@ -126,9 +138,11 @@ class OpsService: | |||
| :param tracing_provider: tracing provider | |||
| :return: | |||
| """ | |||
| trace_config = db.session.query(TraceAppConfig).filter( | |||
| TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider | |||
| ).first() | |||
| trace_config = ( | |||
| db.session.query(TraceAppConfig) | |||
| .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) | |||
| .first() | |||
| ) | |||
| if not trace_config: | |||
| return None | |||
| @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) | |||
| class RecommendedAppService: | |||
| builtin_data: Optional[dict] = None | |||
| @classmethod | |||
| @@ -27,21 +26,21 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE | |||
| if mode == 'remote': | |||
| if mode == "remote": | |||
| try: | |||
| result = cls._fetch_recommended_apps_from_dify_official(language) | |||
| except Exception as e: | |||
| logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.') | |||
| logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") | |||
| result = cls._fetch_recommended_apps_from_builtin(language) | |||
| elif mode == 'db': | |||
| elif mode == "db": | |||
| result = cls._fetch_recommended_apps_from_db(language) | |||
| elif mode == 'builtin': | |||
| elif mode == "builtin": | |||
| result = cls._fetch_recommended_apps_from_builtin(language) | |||
| else: | |||
| raise ValueError(f'invalid fetch recommended apps mode: {mode}') | |||
| raise ValueError(f"invalid fetch recommended apps mode: {mode}") | |||
| if not result.get('recommended_apps') and language != 'en-US': | |||
| result = cls._fetch_recommended_apps_from_builtin('en-US') | |||
| if not result.get("recommended_apps") and language != "en-US": | |||
| result = cls._fetch_recommended_apps_from_builtin("en-US") | |||
| return result | |||
| @@ -52,16 +51,18 @@ class RecommendedAppService: | |||
| :param language: language | |||
| :return: | |||
| """ | |||
| recommended_apps = db.session.query(RecommendedApp).filter( | |||
| RecommendedApp.is_listed == True, | |||
| RecommendedApp.language == language | |||
| ).all() | |||
| recommended_apps = ( | |||
| db.session.query(RecommendedApp) | |||
| .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) | |||
| .all() | |||
| ) | |||
| if len(recommended_apps) == 0: | |||
| recommended_apps = db.session.query(RecommendedApp).filter( | |||
| RecommendedApp.is_listed == True, | |||
| RecommendedApp.language == languages[0] | |||
| ).all() | |||
| recommended_apps = ( | |||
| db.session.query(RecommendedApp) | |||
| .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) | |||
| .all() | |||
| ) | |||
| categories = set() | |||
| recommended_apps_result = [] | |||
| @@ -75,28 +76,28 @@ class RecommendedAppService: | |||
| continue | |||
| recommended_app_result = { | |||
| 'id': recommended_app.id, | |||
| 'app': { | |||
| 'id': app.id, | |||
| 'name': app.name, | |||
| 'mode': app.mode, | |||
| 'icon': app.icon, | |||
| 'icon_background': app.icon_background | |||
| "id": recommended_app.id, | |||
| "app": { | |||
| "id": app.id, | |||
| "name": app.name, | |||
| "mode": app.mode, | |||
| "icon": app.icon, | |||
| "icon_background": app.icon_background, | |||
| }, | |||
| 'app_id': recommended_app.app_id, | |||
| 'description': site.description, | |||
| 'copyright': site.copyright, | |||
| 'privacy_policy': site.privacy_policy, | |||
| 'custom_disclaimer': site.custom_disclaimer, | |||
| 'category': recommended_app.category, | |||
| 'position': recommended_app.position, | |||
| 'is_listed': recommended_app.is_listed | |||
| "app_id": recommended_app.app_id, | |||
| "description": site.description, | |||
| "copyright": site.copyright, | |||
| "privacy_policy": site.privacy_policy, | |||
| "custom_disclaimer": site.custom_disclaimer, | |||
| "category": recommended_app.category, | |||
| "position": recommended_app.position, | |||
| "is_listed": recommended_app.is_listed, | |||
| } | |||
| recommended_apps_result.append(recommended_app_result) | |||
| categories.add(recommended_app.category) # add category to categories | |||
| return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} | |||
| return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} | |||
| @classmethod | |||
| def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: | |||
| @@ -106,16 +107,16 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN | |||
| url = f'{domain}/apps?language={language}' | |||
| url = f"{domain}/apps?language={language}" | |||
| response = requests.get(url, timeout=(3, 10)) | |||
| if response.status_code != 200: | |||
| raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}') | |||
| raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") | |||
| result = response.json() | |||
| if "categories" in result: | |||
| result["categories"] = sorted(result["categories"]) | |||
| return result | |||
| @classmethod | |||
| @@ -126,7 +127,7 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| builtin_data = cls._get_builtin_data() | |||
| return builtin_data.get('recommended_apps', {}).get(language) | |||
| return builtin_data.get("recommended_apps", {}).get(language) | |||
| @classmethod | |||
| def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: | |||
| @@ -136,18 +137,18 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE | |||
| if mode == 'remote': | |||
| if mode == "remote": | |||
| try: | |||
| result = cls._fetch_recommended_app_detail_from_dify_official(app_id) | |||
| except Exception as e: | |||
| logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.') | |||
| logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") | |||
| result = cls._fetch_recommended_app_detail_from_builtin(app_id) | |||
| elif mode == 'db': | |||
| elif mode == "db": | |||
| result = cls._fetch_recommended_app_detail_from_db(app_id) | |||
| elif mode == 'builtin': | |||
| elif mode == "builtin": | |||
| result = cls._fetch_recommended_app_detail_from_builtin(app_id) | |||
| else: | |||
| raise ValueError(f'invalid fetch recommended app detail mode: {mode}') | |||
| raise ValueError(f"invalid fetch recommended app detail mode: {mode}") | |||
| return result | |||
| @@ -159,7 +160,7 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN | |||
| url = f'{domain}/apps/{app_id}' | |||
| url = f"{domain}/apps/{app_id}" | |||
| response = requests.get(url, timeout=(3, 10)) | |||
| if response.status_code != 200: | |||
| return None | |||
| @@ -174,10 +175,11 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| # is in public recommended list | |||
| recommended_app = db.session.query(RecommendedApp).filter( | |||
| RecommendedApp.is_listed == True, | |||
| RecommendedApp.app_id == app_id | |||
| ).first() | |||
| recommended_app = ( | |||
| db.session.query(RecommendedApp) | |||
| .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) | |||
| .first() | |||
| ) | |||
| if not recommended_app: | |||
| return None | |||
| @@ -188,12 +190,12 @@ class RecommendedAppService: | |||
| return None | |||
| return { | |||
| 'id': app_model.id, | |||
| 'name': app_model.name, | |||
| 'icon': app_model.icon, | |||
| 'icon_background': app_model.icon_background, | |||
| 'mode': app_model.mode, | |||
| 'export_data': AppDslService.export_dsl(app_model=app_model) | |||
| "id": app_model.id, | |||
| "name": app_model.name, | |||
| "icon": app_model.icon, | |||
| "icon_background": app_model.icon_background, | |||
| "mode": app_model.mode, | |||
| "export_data": AppDslService.export_dsl(app_model=app_model), | |||
| } | |||
| @classmethod | |||
| @@ -204,7 +206,7 @@ class RecommendedAppService: | |||
| :return: | |||
| """ | |||
| builtin_data = cls._get_builtin_data() | |||
| return builtin_data.get('app_details', {}).get(app_id) | |||
| return builtin_data.get("app_details", {}).get(app_id) | |||
| @classmethod | |||
| def _get_builtin_data(cls) -> dict: | |||
| @@ -216,7 +218,7 @@ class RecommendedAppService: | |||
| return cls.builtin_data | |||
| root_path = current_app.root_path | |||
| with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f: | |||
| with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: | |||
| json_data = f.read() | |||
| data = json.loads(json_data) | |||
| cls.builtin_data = data | |||
| @@ -229,27 +231,24 @@ class RecommendedAppService: | |||
| Fetch all recommended apps and export datas | |||
| :return: | |||
| """ | |||
| templates = { | |||
| "recommended_apps": {}, | |||
| "app_details": {} | |||
| } | |||
| templates = {"recommended_apps": {}, "app_details": {}} | |||
| for language in languages: | |||
| try: | |||
| result = cls._fetch_recommended_apps_from_dify_official(language) | |||
| except Exception as e: | |||
| logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.') | |||
| logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.") | |||
| continue | |||
| templates['recommended_apps'][language] = result | |||
| templates["recommended_apps"][language] = result | |||
| for recommended_app in result.get('recommended_apps'): | |||
| app_id = recommended_app.get('app_id') | |||
| for recommended_app in result.get("recommended_apps"): | |||
| app_id = recommended_app.get("app_id") | |||
| # get app detail | |||
| app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) | |||
| if not app_detail: | |||
| continue | |||
| templates['app_details'][app_id] = app_detail | |||
| templates["app_details"][app_id] = app_detail | |||
| return templates | |||
| @@ -10,46 +10,48 @@ from services.message_service import MessageService | |||
| class SavedMessageService: | |||
| @classmethod | |||
| def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], limit: int) -> InfiniteScrollPagination: | |||
| saved_messages = db.session.query(SavedMessage).filter( | |||
| SavedMessage.app_id == app_model.id, | |||
| SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), | |||
| SavedMessage.created_by == user.id | |||
| ).order_by(SavedMessage.created_at.desc()).all() | |||
| def pagination_by_last_id( | |||
| cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int | |||
| ) -> InfiniteScrollPagination: | |||
| saved_messages = ( | |||
| db.session.query(SavedMessage) | |||
| .filter( | |||
| SavedMessage.app_id == app_model.id, | |||
| SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| SavedMessage.created_by == user.id, | |||
| ) | |||
| .order_by(SavedMessage.created_at.desc()) | |||
| .all() | |||
| ) | |||
| message_ids = [sm.message_id for sm in saved_messages] | |||
| return MessageService.pagination_by_last_id( | |||
| app_model=app_model, | |||
| user=user, | |||
| last_id=last_id, | |||
| limit=limit, | |||
| include_ids=message_ids | |||
| app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids | |||
| ) | |||
| @classmethod | |||
| def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): | |||
| saved_message = db.session.query(SavedMessage).filter( | |||
| SavedMessage.app_id == app_model.id, | |||
| SavedMessage.message_id == message_id, | |||
| SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), | |||
| SavedMessage.created_by == user.id | |||
| ).first() | |||
| saved_message = ( | |||
| db.session.query(SavedMessage) | |||
| .filter( | |||
| SavedMessage.app_id == app_model.id, | |||
| SavedMessage.message_id == message_id, | |||
| SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| SavedMessage.created_by == user.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if saved_message: | |||
| return | |||
| message = MessageService.get_message( | |||
| app_model=app_model, | |||
| user=user, | |||
| message_id=message_id | |||
| ) | |||
| message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) | |||
| saved_message = SavedMessage( | |||
| app_id=app_model.id, | |||
| message_id=message.id, | |||
| created_by_role='account' if isinstance(user, Account) else 'end_user', | |||
| created_by=user.id | |||
| created_by_role="account" if isinstance(user, Account) else "end_user", | |||
| created_by=user.id, | |||
| ) | |||
| db.session.add(saved_message) | |||
| @@ -57,12 +59,16 @@ class SavedMessageService: | |||
| @classmethod | |||
| def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): | |||
| saved_message = db.session.query(SavedMessage).filter( | |||
| SavedMessage.app_id == app_model.id, | |||
| SavedMessage.message_id == message_id, | |||
| SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), | |||
| SavedMessage.created_by == user.id | |||
| ).first() | |||
| saved_message = ( | |||
| db.session.query(SavedMessage) | |||
| .filter( | |||
| SavedMessage.app_id == app_model.id, | |||
| SavedMessage.message_id == message_id, | |||
| SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| SavedMessage.created_by == user.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not saved_message: | |||
| return | |||
| @@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding | |||
| class TagService: | |||
| @staticmethod | |||
| def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list: | |||
| query = db.session.query( | |||
| Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count') | |||
| ).outerjoin( | |||
| TagBinding, Tag.id == TagBinding.tag_id | |||
| ).filter( | |||
| Tag.type == tag_type, | |||
| Tag.tenant_id == current_tenant_id | |||
| query = ( | |||
| db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) | |||
| .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) | |||
| .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) | |||
| ) | |||
| if keyword: | |||
| query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%'))) | |||
| query = query.group_by( | |||
| Tag.id | |||
| ) | |||
| query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) | |||
| query = query.group_by(Tag.id) | |||
| results = query.order_by(Tag.created_at.desc()).all() | |||
| return results | |||
| @staticmethod | |||
| def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: | |||
| tags = db.session.query(Tag).filter( | |||
| Tag.id.in_(tag_ids), | |||
| Tag.tenant_id == current_tenant_id, | |||
| Tag.type == tag_type | |||
| ).all() | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) | |||
| .all() | |||
| ) | |||
| if not tags: | |||
| return [] | |||
| tag_ids = [tag.id for tag in tags] | |||
| tag_bindings = db.session.query( | |||
| TagBinding.target_id | |||
| ).filter( | |||
| TagBinding.tag_id.in_(tag_ids), | |||
| TagBinding.tenant_id == current_tenant_id | |||
| ).all() | |||
| tag_bindings = ( | |||
| db.session.query(TagBinding.target_id) | |||
| .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) | |||
| .all() | |||
| ) | |||
| if not tag_bindings: | |||
| return [] | |||
| results = [tag_binding.target_id for tag_binding in tag_bindings] | |||
| @@ -51,27 +45,28 @@ class TagService: | |||
| @staticmethod | |||
| def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: | |||
| tags = db.session.query(Tag).join( | |||
| TagBinding, | |||
| Tag.id == TagBinding.tag_id | |||
| ).filter( | |||
| TagBinding.target_id == target_id, | |||
| TagBinding.tenant_id == current_tenant_id, | |||
| Tag.tenant_id == current_tenant_id, | |||
| Tag.type == tag_type | |||
| ).all() | |||
| tags = ( | |||
| db.session.query(Tag) | |||
| .join(TagBinding, Tag.id == TagBinding.tag_id) | |||
| .filter( | |||
| TagBinding.target_id == target_id, | |||
| TagBinding.tenant_id == current_tenant_id, | |||
| Tag.tenant_id == current_tenant_id, | |||
| Tag.type == tag_type, | |||
| ) | |||
| .all() | |||
| ) | |||
| return tags if tags else [] | |||
| @staticmethod | |||
| def save_tags(args: dict) -> Tag: | |||
| tag = Tag( | |||
| id=str(uuid.uuid4()), | |||
| name=args['name'], | |||
| type=args['type'], | |||
| name=args["name"], | |||
| type=args["type"], | |||
| created_by=current_user.id, | |||
| tenant_id=current_user.current_tenant_id | |||
| tenant_id=current_user.current_tenant_id, | |||
| ) | |||
| db.session.add(tag) | |||
| db.session.commit() | |||
| @@ -82,7 +77,7 @@ class TagService: | |||
| tag = db.session.query(Tag).filter(Tag.id == tag_id).first() | |||
| if not tag: | |||
| raise NotFound("Tag not found") | |||
| tag.name = args['name'] | |||
| tag.name = args["name"] | |||
| db.session.commit() | |||
| return tag | |||
| @@ -107,20 +102,21 @@ class TagService: | |||
| @staticmethod | |||
| def save_tag_binding(args): | |||
| # check if target exists | |||
| TagService.check_target_exists(args['type'], args['target_id']) | |||
| TagService.check_target_exists(args["type"], args["target_id"]) | |||
| # save tag binding | |||
| for tag_id in args['tag_ids']: | |||
| tag_binding = db.session.query(TagBinding).filter( | |||
| TagBinding.tag_id == tag_id, | |||
| TagBinding.target_id == args['target_id'] | |||
| ).first() | |||
| for tag_id in args["tag_ids"]: | |||
| tag_binding = ( | |||
| db.session.query(TagBinding) | |||
| .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) | |||
| .first() | |||
| ) | |||
| if tag_binding: | |||
| continue | |||
| new_tag_binding = TagBinding( | |||
| tag_id=tag_id, | |||
| target_id=args['target_id'], | |||
| target_id=args["target_id"], | |||
| tenant_id=current_user.current_tenant_id, | |||
| created_by=current_user.id | |||
| created_by=current_user.id, | |||
| ) | |||
| db.session.add(new_tag_binding) | |||
| db.session.commit() | |||
| @@ -128,34 +124,34 @@ class TagService: | |||
| @staticmethod | |||
| def delete_tag_binding(args): | |||
| # check if target exists | |||
| TagService.check_target_exists(args['type'], args['target_id']) | |||
| TagService.check_target_exists(args["type"], args["target_id"]) | |||
| # delete tag binding | |||
| tag_bindings = db.session.query(TagBinding).filter( | |||
| TagBinding.target_id == args['target_id'], | |||
| TagBinding.tag_id == (args['tag_id']) | |||
| ).first() | |||
| tag_bindings = ( | |||
| db.session.query(TagBinding) | |||
| .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) | |||
| .first() | |||
| ) | |||
| if tag_bindings: | |||
| db.session.delete(tag_bindings) | |||
| db.session.commit() | |||
| @staticmethod | |||
| def check_target_exists(type: str, target_id: str): | |||
| if type == 'knowledge': | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == current_user.current_tenant_id, | |||
| Dataset.id == target_id | |||
| ).first() | |||
| if type == "knowledge": | |||
| dataset = ( | |||
| db.session.query(Dataset) | |||
| .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) | |||
| .first() | |||
| ) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found") | |||
| elif type == 'app': | |||
| app = db.session.query(App).filter( | |||
| App.tenant_id == current_user.current_tenant_id, | |||
| App.id == target_id | |||
| ).first() | |||
| elif type == "app": | |||
| app = ( | |||
| db.session.query(App) | |||
| .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) | |||
| .first() | |||
| ) | |||
| if not app: | |||
| raise NotFound("App not found") | |||
| else: | |||
| raise NotFound("Invalid binding type") | |||
| @@ -29,111 +29,107 @@ class ApiToolManageService: | |||
| @staticmethod | |||
| def parser_api_schema(schema: str) -> list[ApiToolBundle]: | |||
| """ | |||
| parse api schema to tool bundle | |||
| parse api schema to tool bundle | |||
| """ | |||
| try: | |||
| warnings = {} | |||
| try: | |||
| tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) | |||
| except Exception as e: | |||
| raise ValueError(f'invalid schema: {str(e)}') | |||
| raise ValueError(f"invalid schema: {str(e)}") | |||
| credentials_schema = [ | |||
| ToolProviderCredentials( | |||
| name='auth_type', | |||
| name="auth_type", | |||
| type=ToolProviderCredentials.CredentialsType.SELECT, | |||
| required=True, | |||
| default='none', | |||
| default="none", | |||
| options=[ | |||
| ToolCredentialsOption(value='none', label=I18nObject( | |||
| en_US='None', | |||
| zh_Hans='无' | |||
| )), | |||
| ToolCredentialsOption(value='api_key', label=I18nObject( | |||
| en_US='Api Key', | |||
| zh_Hans='Api Key' | |||
| )), | |||
| ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), | |||
| ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), | |||
| ], | |||
| placeholder=I18nObject( | |||
| en_US='Select auth type', | |||
| zh_Hans='选择认证方式' | |||
| ) | |||
| placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), | |||
| ), | |||
| ToolProviderCredentials( | |||
| name='api_key_header', | |||
| name="api_key_header", | |||
| type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |||
| required=False, | |||
| placeholder=I18nObject( | |||
| en_US='Enter api key header', | |||
| zh_Hans='输入 api key header,如:X-API-KEY' | |||
| ), | |||
| default='api_key', | |||
| help=I18nObject( | |||
| en_US='HTTP header name for api key', | |||
| zh_Hans='HTTP 头部字段名,用于传递 api key' | |||
| ) | |||
| placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), | |||
| default="api_key", | |||
| help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), | |||
| ), | |||
| ToolProviderCredentials( | |||
| name='api_key_value', | |||
| name="api_key_value", | |||
| type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |||
| required=False, | |||
| placeholder=I18nObject( | |||
| en_US='Enter api key', | |||
| zh_Hans='输入 api key' | |||
| ), | |||
| default='' | |||
| placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), | |||
| default="", | |||
| ), | |||
| ] | |||
| return jsonable_encoder({ | |||
| 'schema_type': schema_type, | |||
| 'parameters_schema': tool_bundles, | |||
| 'credentials_schema': credentials_schema, | |||
| 'warning': warnings | |||
| }) | |||
| return jsonable_encoder( | |||
| { | |||
| "schema_type": schema_type, | |||
| "parameters_schema": tool_bundles, | |||
| "credentials_schema": credentials_schema, | |||
| "warning": warnings, | |||
| } | |||
| ) | |||
| except Exception as e: | |||
| raise ValueError(f'invalid schema: {str(e)}') | |||
| raise ValueError(f"invalid schema: {str(e)}") | |||
| @staticmethod | |||
| def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: | |||
| """ | |||
| convert schema to tool bundles | |||
| convert schema to tool bundles | |||
| :return: the list of tool bundles, description | |||
| :return: the list of tool bundles, description | |||
| """ | |||
| try: | |||
| tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) | |||
| return tool_bundles | |||
| except Exception as e: | |||
| raise ValueError(f'invalid schema: {str(e)}') | |||
| raise ValueError(f"invalid schema: {str(e)}") | |||
| @staticmethod | |||
| def create_api_tool_provider( | |||
| user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, | |||
| schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] | |||
| user_id: str, | |||
| tenant_id: str, | |||
| provider_name: str, | |||
| icon: dict, | |||
| credentials: dict, | |||
| schema_type: str, | |||
| schema: str, | |||
| privacy_policy: str, | |||
| custom_disclaimer: str, | |||
| labels: list[str], | |||
| ): | |||
| """ | |||
| create api tool provider | |||
| create api tool provider | |||
| """ | |||
| if schema_type not in [member.value for member in ApiProviderSchemaType]: | |||
| raise ValueError(f'invalid schema type {schema}') | |||
| raise ValueError(f"invalid schema type {schema}") | |||
| # check if the provider exists | |||
| provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider_name, | |||
| ).first() | |||
| provider: ApiToolProvider = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider_name, | |||
| ) | |||
| .first() | |||
| ) | |||
| if provider is not None: | |||
| raise ValueError(f'provider {provider_name} already exists') | |||
| raise ValueError(f"provider {provider_name} already exists") | |||
| # parse openapi to tool bundle | |||
| extra_info = {} | |||
| # extra info like description will be set here | |||
| tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) | |||
| if len(tool_bundles) > 100: | |||
| raise ValueError('the number of apis should be less than 100') | |||
| raise ValueError("the number of apis should be less than 100") | |||
| # create db provider | |||
| db_provider = ApiToolProvider( | |||
| @@ -142,19 +138,19 @@ class ApiToolManageService: | |||
| name=provider_name, | |||
| icon=json.dumps(icon), | |||
| schema=schema, | |||
| description=extra_info.get('description', ''), | |||
| description=extra_info.get("description", ""), | |||
| schema_type_str=schema_type, | |||
| tools_str=json.dumps(jsonable_encoder(tool_bundles)), | |||
| credentials_str={}, | |||
| privacy_policy=privacy_policy, | |||
| custom_disclaimer=custom_disclaimer | |||
| custom_disclaimer=custom_disclaimer, | |||
| ) | |||
| if 'auth_type' not in credentials: | |||
| raise ValueError('auth_type is required') | |||
| if "auth_type" not in credentials: | |||
| raise ValueError("auth_type is required") | |||
| # get auth type, none or api key | |||
| auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) | |||
| auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) | |||
| # create provider entity | |||
| provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) | |||
| @@ -172,14 +168,12 @@ class ApiToolManageService: | |||
| # update labels | |||
| ToolLabelManager.update_tool_labels(provider_controller, labels) | |||
| return { 'result': 'success' } | |||
| return {"result": "success"} | |||
| @staticmethod | |||
| def get_api_tool_provider_remote_schema( | |||
| user_id: str, tenant_id: str, url: str | |||
| ): | |||
| def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): | |||
| """ | |||
| get api tool provider remote schema | |||
| get api tool provider remote schema | |||
| """ | |||
| headers = { | |||
| "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", | |||
| @@ -189,84 +183,98 @@ class ApiToolManageService: | |||
| try: | |||
| response = get(url, headers=headers, timeout=10) | |||
| if response.status_code != 200: | |||
| raise ValueError(f'Got status code {response.status_code}') | |||
| raise ValueError(f"Got status code {response.status_code}") | |||
| schema = response.text | |||
| # try to parse schema, avoid SSRF attack | |||
| ApiToolManageService.parser_api_schema(schema) | |||
| except Exception as e: | |||
| logger.error(f"parse api schema error: {str(e)}") | |||
| raise ValueError('invalid schema, please check the url you provided') | |||
| return { | |||
| 'schema': schema | |||
| } | |||
| raise ValueError("invalid schema, please check the url you provided") | |||
| return {"schema": schema} | |||
| @staticmethod | |||
| def list_api_tool_provider_tools( | |||
| user_id: str, tenant_id: str, provider: str | |||
| ) -> list[UserTool]: | |||
| def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: | |||
| """ | |||
| list api tool provider tools | |||
| list api tool provider tools | |||
| """ | |||
| provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider, | |||
| ).first() | |||
| provider: ApiToolProvider = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider, | |||
| ) | |||
| .first() | |||
| ) | |||
| if provider is None: | |||
| raise ValueError(f'you have not added provider {provider}') | |||
| raise ValueError(f"you have not added provider {provider}") | |||
| controller = ToolTransformService.api_provider_to_controller(db_provider=provider) | |||
| labels = ToolLabelManager.get_tool_labels(controller) | |||
| return [ | |||
| ToolTransformService.tool_to_user_tool( | |||
| tool_bundle, | |||
| labels=labels, | |||
| ) for tool_bundle in provider.tools | |||
| ) | |||
| for tool_bundle in provider.tools | |||
| ] | |||
| @staticmethod | |||
| def update_api_tool_provider( | |||
| user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, | |||
| schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] | |||
| user_id: str, | |||
| tenant_id: str, | |||
| provider_name: str, | |||
| original_provider: str, | |||
| icon: dict, | |||
| credentials: dict, | |||
| schema_type: str, | |||
| schema: str, | |||
| privacy_policy: str, | |||
| custom_disclaimer: str, | |||
| labels: list[str], | |||
| ): | |||
| """ | |||
| update api tool provider | |||
| update api tool provider | |||
| """ | |||
| if schema_type not in [member.value for member in ApiProviderSchemaType]: | |||
| raise ValueError(f'invalid schema type {schema}') | |||
| raise ValueError(f"invalid schema type {schema}") | |||
| # check if the provider exists | |||
| provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == original_provider, | |||
| ).first() | |||
| provider: ApiToolProvider = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == original_provider, | |||
| ) | |||
| .first() | |||
| ) | |||
| if provider is None: | |||
| raise ValueError(f'api provider {provider_name} does not exists') | |||
| raise ValueError(f"api provider {provider_name} does not exists") | |||
| # parse openapi to tool bundle | |||
| extra_info = {} | |||
| # extra info like description will be set here | |||
| tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) | |||
| # update db provider | |||
| provider.name = provider_name | |||
| provider.icon = json.dumps(icon) | |||
| provider.schema = schema | |||
| provider.description = extra_info.get('description', '') | |||
| provider.description = extra_info.get("description", "") | |||
| provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value | |||
| provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) | |||
| provider.privacy_policy = privacy_policy | |||
| provider.custom_disclaimer = custom_disclaimer | |||
| if 'auth_type' not in credentials: | |||
| raise ValueError('auth_type is required') | |||
| if "auth_type" not in credentials: | |||
| raise ValueError("auth_type is required") | |||
| # get auth type, none or api key | |||
| auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) | |||
| auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) | |||
| # create provider entity | |||
| provider_controller = ApiToolProviderController.from_db(provider, auth_type) | |||
| @@ -295,84 +303,91 @@ class ApiToolManageService: | |||
| # update labels | |||
| ToolLabelManager.update_tool_labels(provider_controller, labels) | |||
| return { 'result': 'success' } | |||
| return {"result": "success"} | |||
| @staticmethod | |||
| def delete_api_tool_provider( | |||
| user_id: str, tenant_id: str, provider_name: str | |||
| ): | |||
| def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): | |||
| """ | |||
| delete tool provider | |||
| delete tool provider | |||
| """ | |||
| provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider_name, | |||
| ).first() | |||
| provider: ApiToolProvider = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider_name, | |||
| ) | |||
| .first() | |||
| ) | |||
| if provider is None: | |||
| raise ValueError(f'you have not added provider {provider_name}') | |||
| raise ValueError(f"you have not added provider {provider_name}") | |||
| db.session.delete(provider) | |||
| db.session.commit() | |||
| return { 'result': 'success' } | |||
| return {"result": "success"} | |||
| @staticmethod | |||
| def get_api_tool_provider( | |||
| user_id: str, tenant_id: str, provider: str | |||
| ): | |||
| def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): | |||
| """ | |||
| get api tool provider | |||
| get api tool provider | |||
| """ | |||
| return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) | |||
| @staticmethod | |||
| def test_api_tool_preview( | |||
| tenant_id: str, | |||
| tenant_id: str, | |||
| provider_name: str, | |||
| tool_name: str, | |||
| credentials: dict, | |||
| parameters: dict, | |||
| schema_type: str, | |||
| schema: str | |||
| tool_name: str, | |||
| credentials: dict, | |||
| parameters: dict, | |||
| schema_type: str, | |||
| schema: str, | |||
| ): | |||
| """ | |||
| test api tool before adding api tool provider | |||
| test api tool before adding api tool provider | |||
| """ | |||
| if schema_type not in [member.value for member in ApiProviderSchemaType]: | |||
| raise ValueError(f'invalid schema type {schema_type}') | |||
| raise ValueError(f"invalid schema type {schema_type}") | |||
| try: | |||
| tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) | |||
| except Exception as e: | |||
| raise ValueError('invalid schema') | |||
| raise ValueError("invalid schema") | |||
| # get tool bundle | |||
| tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) | |||
| if tool_bundle is None: | |||
| raise ValueError(f'invalid tool name {tool_name}') | |||
| db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider_name, | |||
| ).first() | |||
| raise ValueError(f"invalid tool name {tool_name}") | |||
| db_provider: ApiToolProvider = ( | |||
| db.session.query(ApiToolProvider) | |||
| .filter( | |||
| ApiToolProvider.tenant_id == tenant_id, | |||
| ApiToolProvider.name == provider_name, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not db_provider: | |||
| # create a fake db provider | |||
| db_provider = ApiToolProvider( | |||
| tenant_id='', user_id='', name='', icon='', | |||
| tenant_id="", | |||
| user_id="", | |||
| name="", | |||
| icon="", | |||
| schema=schema, | |||
| description='', | |||
| description="", | |||
| schema_type_str=ApiProviderSchemaType.OPENAPI.value, | |||
| tools_str=json.dumps(jsonable_encoder(tool_bundles)), | |||
| credentials_str=json.dumps(credentials), | |||
| ) | |||
| if 'auth_type' not in credentials: | |||
| raise ValueError('auth_type is required') | |||
| if "auth_type" not in credentials: | |||
| raise ValueError("auth_type is required") | |||
| # get auth type, none or api key | |||
| auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) | |||
| auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) | |||
| # create provider entity | |||
| provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) | |||
| @@ -381,10 +396,7 @@ class ApiToolManageService: | |||
| # decrypt credentials | |||
| if db_provider.id: | |||
| tool_configuration = ToolConfigurationManager( | |||
| tenant_id=tenant_id, | |||
| provider_controller=provider_controller | |||
| ) | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) | |||
| # check if the credential has changed, save the original credential | |||
| masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) | |||
| @@ -396,27 +408,27 @@ class ApiToolManageService: | |||
| provider_controller.validate_credentials_format(credentials) | |||
| # get tool | |||
| tool = provider_controller.get_tool(tool_name) | |||
| tool = tool.fork_tool_runtime(runtime={ | |||
| 'credentials': credentials, | |||
| 'tenant_id': tenant_id, | |||
| }) | |||
| tool = tool.fork_tool_runtime( | |||
| runtime={ | |||
| "credentials": credentials, | |||
| "tenant_id": tenant_id, | |||
| } | |||
| ) | |||
| result = tool.validate_credentials(credentials, parameters) | |||
| except Exception as e: | |||
| return { 'error': str(e) } | |||
| return { 'result': result or 'empty response' } | |||
| return {"error": str(e)} | |||
| return {"result": result or "empty response"} | |||
| @staticmethod | |||
| def list_api_tools( | |||
| user_id: str, tenant_id: str | |||
| ) -> list[UserToolProvider]: | |||
| def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: | |||
| """ | |||
| list api tools | |||
| list api tools | |||
| """ | |||
| # get all api providers | |||
| db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( | |||
| ApiToolProvider.tenant_id == tenant_id | |||
| ).all() or [] | |||
| db_providers: list[ApiToolProvider] = ( | |||
| db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] | |||
| ) | |||
| result: list[UserToolProvider] = [] | |||
| @@ -425,26 +437,21 @@ class ApiToolManageService: | |||
| provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) | |||
| labels = ToolLabelManager.get_tool_labels(provider_controller) | |||
| user_provider = ToolTransformService.api_provider_to_user_provider( | |||
| provider_controller, | |||
| db_provider=provider, | |||
| decrypt_credentials=True | |||
| provider_controller, db_provider=provider, decrypt_credentials=True | |||
| ) | |||
| user_provider.labels = labels | |||
| # add icon | |||
| ToolTransformService.repack_provider(user_provider) | |||
| tools = provider_controller.get_tools( | |||
| user_id=user_id, tenant_id=tenant_id | |||
| ) | |||
| tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) | |||
| for tool in tools: | |||
| user_provider.tools.append(ToolTransformService.tool_to_user_tool( | |||
| tenant_id=tenant_id, | |||
| tool=tool, | |||
| credentials=user_provider.original_credentials, | |||
| labels=labels | |||
| )) | |||
| user_provider.tools.append( | |||
| ToolTransformService.tool_to_user_tool( | |||
| tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels | |||
| ) | |||
| ) | |||
| result.append(user_provider) | |||
| @@ -20,21 +20,25 @@ logger = logging.getLogger(__name__) | |||
| class BuiltinToolManageService: | |||
| @staticmethod | |||
| def list_builtin_tool_provider_tools( | |||
| user_id: str, tenant_id: str, provider: str | |||
| ) -> list[UserTool]: | |||
| def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: | |||
| """ | |||
| list builtin tool provider tools | |||
| list builtin tool provider tools | |||
| """ | |||
| provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) | |||
| tools = provider_controller.get_tools() | |||
| tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_provider_configurations = ToolConfigurationManager( | |||
| tenant_id=tenant_id, provider_controller=provider_controller | |||
| ) | |||
| # check if user has added the provider | |||
| builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| ).first() | |||
| builtin_provider: BuiltinToolProvider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| ) | |||
| .first() | |||
| ) | |||
| credentials = {} | |||
| if builtin_provider is not None: | |||
| @@ -44,47 +48,47 @@ class BuiltinToolManageService: | |||
| result = [] | |||
| for tool in tools: | |||
| result.append(ToolTransformService.tool_to_user_tool( | |||
| tool=tool, | |||
| credentials=credentials, | |||
| tenant_id=tenant_id, | |||
| labels=ToolLabelManager.get_tool_labels(provider_controller) | |||
| )) | |||
| result.append( | |||
| ToolTransformService.tool_to_user_tool( | |||
| tool=tool, | |||
| credentials=credentials, | |||
| tenant_id=tenant_id, | |||
| labels=ToolLabelManager.get_tool_labels(provider_controller), | |||
| ) | |||
| ) | |||
| return result | |||
| @staticmethod | |||
| def list_builtin_provider_credentials_schema( | |||
| provider_name | |||
| ): | |||
| def list_builtin_provider_credentials_schema(provider_name): | |||
| """ | |||
| list builtin provider credentials schema | |||
| list builtin provider credentials schema | |||
| :return: the list of tool providers | |||
| :return: the list of tool providers | |||
| """ | |||
| provider = ToolManager.get_builtin_provider(provider_name) | |||
| return jsonable_encoder([ | |||
| v for _, v in (provider.credentials_schema or {}).items() | |||
| ]) | |||
| return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) | |||
| @staticmethod | |||
| def update_builtin_tool_provider( | |||
| user_id: str, tenant_id: str, provider_name: str, credentials: dict | |||
| ): | |||
| def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): | |||
| """ | |||
| update builtin tool provider | |||
| update builtin tool provider | |||
| """ | |||
| # get if the provider exists | |||
| provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider_name, | |||
| ).first() | |||
| provider: BuiltinToolProvider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider_name, | |||
| ) | |||
| .first() | |||
| ) | |||
| try: | |||
| # get provider | |||
| provider_controller = ToolManager.get_builtin_provider(provider_name) | |||
| if not provider_controller.need_credentials: | |||
| raise ValueError(f'provider {provider_name} does not need credentials') | |||
| raise ValueError(f"provider {provider_name} does not need credentials") | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| # get original credentials if exists | |||
| if provider is not None: | |||
| @@ -121,19 +125,21 @@ class BuiltinToolManageService: | |||
| # delete cache | |||
| tool_configuration.delete_tool_credentials_cache() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| @staticmethod | |||
| def get_builtin_tool_provider_credentials( | |||
| user_id: str, tenant_id: str, provider: str | |||
| ): | |||
| def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): | |||
| """ | |||
| get builtin tool provider credentials | |||
| get builtin tool provider credentials | |||
| """ | |||
| provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| ).first() | |||
| provider: BuiltinToolProvider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider, | |||
| ) | |||
| .first() | |||
| ) | |||
| if provider is None: | |||
| return {} | |||
| @@ -145,19 +151,21 @@ class BuiltinToolManageService: | |||
| return credentials | |||
| @staticmethod | |||
| def delete_builtin_tool_provider( | |||
| user_id: str, tenant_id: str, provider_name: str | |||
| ): | |||
| def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): | |||
| """ | |||
| delete tool provider | |||
| delete tool provider | |||
| """ | |||
| provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider_name, | |||
| ).first() | |||
| provider: BuiltinToolProvider = ( | |||
| db.session.query(BuiltinToolProvider) | |||
| .filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id, | |||
| BuiltinToolProvider.provider == provider_name, | |||
| ) | |||
| .first() | |||
| ) | |||
| if provider is None: | |||
| raise ValueError(f'you have not added provider {provider_name}') | |||
| raise ValueError(f"you have not added provider {provider_name}") | |||
| db.session.delete(provider) | |||
| db.session.commit() | |||
| @@ -167,38 +175,36 @@ class BuiltinToolManageService: | |||
| tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |||
| tool_configuration.delete_tool_credentials_cache() | |||
| return {'result': 'success'} | |||
| return {"result": "success"} | |||
| @staticmethod | |||
| def get_builtin_tool_provider_icon( | |||
| provider: str | |||
| ): | |||
| def get_builtin_tool_provider_icon(provider: str): | |||
| """ | |||
| get tool provider icon and it's mimetype | |||
| get tool provider icon and it's mimetype | |||
| """ | |||
| icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) | |||
| with open(icon_path, 'rb') as f: | |||
| with open(icon_path, "rb") as f: | |||
| icon_bytes = f.read() | |||
| return icon_bytes, mime_type | |||
| @staticmethod | |||
| def list_builtin_tools( | |||
| user_id: str, tenant_id: str | |||
| ) -> list[UserToolProvider]: | |||
| def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: | |||
| """ | |||
| list builtin tools | |||
| list builtin tools | |||
| """ | |||
| # get all builtin providers | |||
| provider_controllers = ToolManager.list_builtin_providers() | |||
| # get all user added providers | |||
| db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( | |||
| BuiltinToolProvider.tenant_id == tenant_id | |||
| ).all() or [] | |||
| db_providers: list[BuiltinToolProvider] = ( | |||
| db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] | |||
| ) | |||
| # find provider | |||
| find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) | |||
| find_provider = lambda provider: next( | |||
| filter(lambda db_provider: db_provider.provider == provider, db_providers), None | |||
| ) | |||
| result: list[UserToolProvider] = [] | |||
| @@ -209,7 +215,7 @@ class BuiltinToolManageService: | |||
| include_set=dify_config.POSITION_TOOL_INCLUDES_SET, | |||
| exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, | |||
| data=provider_controller, | |||
| name_func=lambda x: x.identity.name | |||
| name_func=lambda x: x.identity.name, | |||
| ): | |||
| continue | |||
| @@ -217,7 +223,7 @@ class BuiltinToolManageService: | |||
| user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( | |||
| provider_controller=provider_controller, | |||
| db_provider=find_provider(provider_controller.identity.name), | |||
| decrypt_credentials=True | |||
| decrypt_credentials=True, | |||
| ) | |||
| # add icon | |||
| @@ -225,12 +231,14 @@ class BuiltinToolManageService: | |||
| tools = provider_controller.get_tools() | |||
| for tool in tools: | |||
| user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( | |||
| tenant_id=tenant_id, | |||
| tool=tool, | |||
| credentials=user_builtin_provider.original_credentials, | |||
| labels=ToolLabelManager.get_tool_labels(provider_controller) | |||
| )) | |||
| user_builtin_provider.tools.append( | |||
| ToolTransformService.tool_to_user_tool( | |||
| tenant_id=tenant_id, | |||
| tool=tool, | |||
| credentials=user_builtin_provider.original_credentials, | |||
| labels=ToolLabelManager.get_tool_labels(provider_controller), | |||
| ) | |||
| ) | |||
| result.append(user_builtin_provider) | |||
| except Exception as e: | |||
| @@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels | |||
| class ToolLabelsService: | |||
| @classmethod | |||
| def list_tool_labels(cls) -> list[ToolLabel]: | |||
| return default_tool_labels | |||
| return default_tool_labels | |||
| @@ -11,13 +11,11 @@ class ToolCommonService: | |||
| @staticmethod | |||
| def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): | |||
| """ | |||
| list tool providers | |||
| list tool providers | |||
| :return: the list of tool providers | |||
| :return: the list of tool providers | |||
| """ | |||
| providers = ToolManager.user_list_providers( | |||
| user_id, tenant_id, typ | |||
| ) | |||
| providers = ToolManager.user_list_providers(user_id, tenant_id, typ) | |||
| # add icon | |||
| for provider in providers: | |||
| @@ -26,4 +24,3 @@ class ToolCommonService: | |||
| result = [provider.to_dict() for provider in providers] | |||
| return result | |||
| @@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi | |||
| logger = logging.getLogger(__name__) | |||
| class ToolTransformService: | |||
| @staticmethod | |||
| def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: | |||
| """ | |||
| get tool provider icon url | |||
| get tool provider icon url | |||
| """ | |||
| url_prefix = (dify_config.CONSOLE_API_URL | |||
| + "/console/api/workspaces/current/tool-provider/") | |||
| url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" | |||
| if provider_type == ToolProviderType.BUILT_IN.value: | |||
| return url_prefix + 'builtin/' + provider_name + '/icon' | |||
| return url_prefix + "builtin/" + provider_name + "/icon" | |||
| elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: | |||
| try: | |||
| return json.loads(icon) | |||
| except: | |||
| return { | |||
| "background": "#252525", | |||
| "content": "\ud83d\ude01" | |||
| } | |||
| return '' | |||
| return {"background": "#252525", "content": "\ud83d\ude01"} | |||
| return "" | |||
| @staticmethod | |||
| def repack_provider(provider: Union[dict, UserToolProvider]): | |||
| """ | |||
| repack provider | |||
| repack provider | |||
| :param provider: the provider dict | |||
| :param provider: the provider dict | |||
| """ | |||
| if isinstance(provider, dict) and 'icon' in provider: | |||
| provider['icon'] = ToolTransformService.get_tool_provider_icon_url( | |||
| provider_type=provider['type'], | |||
| provider_name=provider['name'], | |||
| icon=provider['icon'] | |||
| if isinstance(provider, dict) and "icon" in provider: | |||
| provider["icon"] = ToolTransformService.get_tool_provider_icon_url( | |||
| provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] | |||
| ) | |||
| elif isinstance(provider, UserToolProvider): | |||
| provider.icon = ToolTransformService.get_tool_provider_icon_url( | |||
| provider_type=provider.type.value, | |||
| provider_name=provider.name, | |||
| icon=provider.icon | |||
| provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon | |||
| ) | |||
| @staticmethod | |||
| @@ -92,14 +85,13 @@ class ToolTransformService: | |||
| masked_credentials={}, | |||
| is_team_authorization=False, | |||
| tools=[], | |||
| labels=provider_controller.tool_labels | |||
| labels=provider_controller.tool_labels, | |||
| ) | |||
| # get credentials schema | |||
| schema = provider_controller.get_credentials_schema() | |||
| for name, value in schema.items(): | |||
| result.masked_credentials[name] = \ | |||
| ToolProviderCredentials.CredentialsType.default(value.type) | |||
| result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) | |||
| # check if the provider need credentials | |||
| if not provider_controller.need_credentials: | |||
| @@ -113,8 +105,7 @@ class ToolTransformService: | |||
| # init tool configuration | |||
| tool_configuration = ToolConfigurationManager( | |||
| tenant_id=db_provider.tenant_id, | |||
| provider_controller=provider_controller | |||
| tenant_id=db_provider.tenant_id, provider_controller=provider_controller | |||
| ) | |||
| # decrypt the credentials and mask the credentials | |||
| decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) | |||
| @@ -124,7 +115,7 @@ class ToolTransformService: | |||
| result.original_credentials = decrypted_credentials | |||
| return result | |||
| @staticmethod | |||
| def api_provider_to_controller( | |||
| db_provider: ApiToolProvider, | |||
| @@ -135,25 +126,23 @@ class ToolTransformService: | |||
| # package tool provider controller | |||
| controller = ApiToolProviderController.from_db( | |||
| db_provider=db_provider, | |||
| auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else | |||
| ApiProviderAuthType.NONE | |||
| auth_type=ApiProviderAuthType.API_KEY | |||
| if db_provider.credentials["auth_type"] == "api_key" | |||
| else ApiProviderAuthType.NONE, | |||
| ) | |||
| return controller | |||
| @staticmethod | |||
| def workflow_provider_to_controller( | |||
| db_provider: WorkflowToolProvider | |||
| ) -> WorkflowToolProviderController: | |||
| def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: | |||
| """ | |||
| convert provider controller to provider | |||
| """ | |||
| return WorkflowToolProviderController.from_db(db_provider) | |||
| @staticmethod | |||
| def workflow_provider_to_user_provider( | |||
| provider_controller: WorkflowToolProviderController, | |||
| labels: list[str] = None | |||
| provider_controller: WorkflowToolProviderController, labels: list[str] = None | |||
| ): | |||
| """ | |||
| convert provider controller to user provider | |||
| @@ -175,7 +164,7 @@ class ToolTransformService: | |||
| masked_credentials={}, | |||
| is_team_authorization=True, | |||
| tools=[], | |||
| labels=labels or [] | |||
| labels=labels or [], | |||
| ) | |||
| @staticmethod | |||
| @@ -183,16 +172,16 @@ class ToolTransformService: | |||
| provider_controller: ApiToolProviderController, | |||
| db_provider: ApiToolProvider, | |||
| decrypt_credentials: bool = True, | |||
| labels: list[str] = None | |||
| labels: list[str] = None, | |||
| ) -> UserToolProvider: | |||
| """ | |||
| convert provider controller to user provider | |||
| """ | |||
| username = 'Anonymous' | |||
| username = "Anonymous" | |||
| try: | |||
| username = db_provider.user.name | |||
| except Exception as e: | |||
| logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') | |||
| logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") | |||
| # add provider into providers | |||
| credentials = db_provider.credentials | |||
| result = UserToolProvider( | |||
| @@ -212,14 +201,13 @@ class ToolTransformService: | |||
| masked_credentials={}, | |||
| is_team_authorization=True, | |||
| tools=[], | |||
| labels=labels or [] | |||
| labels=labels or [], | |||
| ) | |||
| if decrypt_credentials: | |||
| # init tool configuration | |||
| tool_configuration = ToolConfigurationManager( | |||
| tenant_id=db_provider.tenant_id, | |||
| provider_controller=provider_controller | |||
| tenant_id=db_provider.tenant_id, provider_controller=provider_controller | |||
| ) | |||
| # decrypt the credentials and mask the credentials | |||
| @@ -229,23 +217,25 @@ class ToolTransformService: | |||
| result.masked_credentials = masked_credentials | |||
| return result | |||
| @staticmethod | |||
| def tool_to_user_tool( | |||
| tool: Union[ApiToolBundle, WorkflowTool, Tool], | |||
| credentials: dict = None, | |||
| tool: Union[ApiToolBundle, WorkflowTool, Tool], | |||
| credentials: dict = None, | |||
| tenant_id: str = None, | |||
| labels: list[str] = None | |||
| labels: list[str] = None, | |||
| ) -> UserTool: | |||
| """ | |||
| convert tool to user tool | |||
| """ | |||
| if isinstance(tool, Tool): | |||
| # fork tool runtime | |||
| tool = tool.fork_tool_runtime(runtime={ | |||
| 'credentials': credentials, | |||
| 'tenant_id': tenant_id, | |||
| }) | |||
| tool = tool.fork_tool_runtime( | |||
| runtime={ | |||
| "credentials": credentials, | |||
| "tenant_id": tenant_id, | |||
| } | |||
| ) | |||
| # get tool parameters | |||
| parameters = tool.parameters or [] | |||
| @@ -270,20 +260,14 @@ class ToolTransformService: | |||
| label=tool.identity.label, | |||
| description=tool.description.human, | |||
| parameters=current_parameters, | |||
| labels=labels | |||
| labels=labels, | |||
| ) | |||
| if isinstance(tool, ApiToolBundle): | |||
| return UserTool( | |||
| author=tool.author, | |||
| name=tool.operation_id, | |||
| label=I18nObject( | |||
| en_US=tool.operation_id, | |||
| zh_Hans=tool.operation_id | |||
| ), | |||
| description=I18nObject( | |||
| en_US=tool.summary or '', | |||
| zh_Hans=tool.summary or '' | |||
| ), | |||
| label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), | |||
| description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), | |||
| parameters=tool.parameters, | |||
| labels=labels | |||
| ) | |||
| labels=labels, | |||
| ) | |||
| @@ -19,10 +19,21 @@ class WorkflowToolManageService: | |||
| """ | |||
| Service class for managing workflow tools. | |||
| """ | |||
| @classmethod | |||
| def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, | |||
| label: str, icon: dict, description: str, | |||
| parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: | |||
| def create_workflow_tool( | |||
| cls, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| workflow_app_id: str, | |||
| name: str, | |||
| label: str, | |||
| icon: dict, | |||
| description: str, | |||
| parameters: list[dict], | |||
| privacy_policy: str = "", | |||
| labels: list[str] = None, | |||
| ) -> dict: | |||
| """ | |||
| Create a workflow tool. | |||
| :param user_id: the user id | |||
| @@ -38,27 +49,28 @@ class WorkflowToolManageService: | |||
| WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) | |||
| # check if the name is unique | |||
| existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| # name or app_id | |||
| or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) | |||
| ).first() | |||
| existing_workflow_tool_provider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| # name or app_id | |||
| or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), | |||
| ) | |||
| .first() | |||
| ) | |||
| if existing_workflow_tool_provider is not None: | |||
| raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') | |||
| app: App = db.session.query(App).filter( | |||
| App.id == workflow_app_id, | |||
| App.tenant_id == tenant_id | |||
| ).first() | |||
| raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") | |||
| app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() | |||
| if app is None: | |||
| raise ValueError(f'App {workflow_app_id} not found') | |||
| raise ValueError(f"App {workflow_app_id} not found") | |||
| workflow: Workflow = app.workflow | |||
| if workflow is None: | |||
| raise ValueError(f'Workflow not found for app {workflow_app_id}') | |||
| raise ValueError(f"Workflow not found for app {workflow_app_id}") | |||
| workflow_tool_provider = WorkflowToolProvider( | |||
| tenant_id=tenant_id, | |||
| user_id=user_id, | |||
| @@ -76,19 +88,26 @@ class WorkflowToolManageService: | |||
| WorkflowToolProviderController.from_db(workflow_tool_provider) | |||
| except Exception as e: | |||
| raise ValueError(str(e)) | |||
| db.session.add(workflow_tool_provider) | |||
| db.session.commit() | |||
| return { | |||
| 'result': 'success' | |||
| } | |||
| return {"result": "success"} | |||
| @classmethod | |||
| def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, | |||
| name: str, label: str, icon: dict, description: str, | |||
| parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: | |||
| def update_workflow_tool( | |||
| cls, | |||
| user_id: str, | |||
| tenant_id: str, | |||
| workflow_tool_id: str, | |||
| name: str, | |||
| label: str, | |||
| icon: dict, | |||
| description: str, | |||
| parameters: list[dict], | |||
| privacy_policy: str = "", | |||
| labels: list[str] = None, | |||
| ) -> dict: | |||
| """ | |||
| Update a workflow tool. | |||
| :param user_id: the user id | |||
| @@ -106,35 +125,39 @@ class WorkflowToolManageService: | |||
| WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) | |||
| # check if the name is unique | |||
| existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.name == name, | |||
| WorkflowToolProvider.id != workflow_tool_id | |||
| ).first() | |||
| existing_workflow_tool_provider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.name == name, | |||
| WorkflowToolProvider.id != workflow_tool_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if existing_workflow_tool_provider is not None: | |||
| raise ValueError(f'Tool with name {name} already exists') | |||
| workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.id == workflow_tool_id | |||
| ).first() | |||
| raise ValueError(f"Tool with name {name} already exists") | |||
| workflow_tool_provider: WorkflowToolProvider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) | |||
| .first() | |||
| ) | |||
| if workflow_tool_provider is None: | |||
| raise ValueError(f'Tool {workflow_tool_id} not found') | |||
| app: App = db.session.query(App).filter( | |||
| App.id == workflow_tool_provider.app_id, | |||
| App.tenant_id == tenant_id | |||
| ).first() | |||
| raise ValueError(f"Tool {workflow_tool_id} not found") | |||
| app: App = ( | |||
| db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() | |||
| ) | |||
| if app is None: | |||
| raise ValueError(f'App {workflow_tool_provider.app_id} not found') | |||
| raise ValueError(f"App {workflow_tool_provider.app_id} not found") | |||
| workflow: Workflow = app.workflow | |||
| if workflow is None: | |||
| raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') | |||
| raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") | |||
| workflow_tool_provider.name = name | |||
| workflow_tool_provider.label = label | |||
| workflow_tool_provider.icon = json.dumps(icon) | |||
| @@ -154,13 +177,10 @@ class WorkflowToolManageService: | |||
| if labels is not None: | |||
| ToolLabelManager.update_tool_labels( | |||
| ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), | |||
| labels | |||
| ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels | |||
| ) | |||
| return { | |||
| 'result': 'success' | |||
| } | |||
| return {"result": "success"} | |||
| @classmethod | |||
| def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: | |||
| @@ -170,9 +190,7 @@ class WorkflowToolManageService: | |||
| :param tenant_id: the tenant id | |||
| :return: the list of tools | |||
| """ | |||
| db_tools = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id | |||
| ).all() | |||
| db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() | |||
| tools = [] | |||
| for provider in db_tools: | |||
| @@ -188,14 +206,12 @@ class WorkflowToolManageService: | |||
| for tool in tools: | |||
| user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( | |||
| provider_controller=tool, | |||
| labels=labels.get(tool.provider_id, []) | |||
| provider_controller=tool, labels=labels.get(tool.provider_id, []) | |||
| ) | |||
| ToolTransformService.repack_provider(user_tool_provider) | |||
| user_tool_provider.tools = [ | |||
| ToolTransformService.tool_to_user_tool( | |||
| tool.get_tools(user_id, tenant_id)[0], | |||
| labels=labels.get(tool.provider_id, []) | |||
| tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) | |||
| ) | |||
| ] | |||
| result.append(user_tool_provider) | |||
| @@ -211,15 +227,12 @@ class WorkflowToolManageService: | |||
| :param workflow_app_id: the workflow app id | |||
| """ | |||
| db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.id == workflow_tool_id | |||
| WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id | |||
| ).delete() | |||
| db.session.commit() | |||
| return { | |||
| 'result': 'success' | |||
| } | |||
| return {"result": "success"} | |||
| @classmethod | |||
| def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: | |||
| @@ -230,40 +243,37 @@ class WorkflowToolManageService: | |||
| :param workflow_app_id: the workflow app id | |||
| :return: the tool | |||
| """ | |||
| db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.id == workflow_tool_id | |||
| ).first() | |||
| db_tool: WorkflowToolProvider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) | |||
| .first() | |||
| ) | |||
| if db_tool is None: | |||
| raise ValueError(f'Tool {workflow_tool_id} not found') | |||
| workflow_app: App = db.session.query(App).filter( | |||
| App.id == db_tool.app_id, | |||
| App.tenant_id == tenant_id | |||
| ).first() | |||
| raise ValueError(f"Tool {workflow_tool_id} not found") | |||
| workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() | |||
| if workflow_app is None: | |||
| raise ValueError(f'App {db_tool.app_id} not found') | |||
| raise ValueError(f"App {db_tool.app_id} not found") | |||
| tool = ToolTransformService.workflow_provider_to_controller(db_tool) | |||
| return { | |||
| 'name': db_tool.name, | |||
| 'label': db_tool.label, | |||
| 'workflow_tool_id': db_tool.id, | |||
| 'workflow_app_id': db_tool.app_id, | |||
| 'icon': json.loads(db_tool.icon), | |||
| 'description': db_tool.description, | |||
| 'parameters': jsonable_encoder(db_tool.parameter_configurations), | |||
| 'tool': ToolTransformService.tool_to_user_tool( | |||
| tool.get_tools(user_id, tenant_id)[0], | |||
| labels=ToolLabelManager.get_tool_labels(tool) | |||
| "name": db_tool.name, | |||
| "label": db_tool.label, | |||
| "workflow_tool_id": db_tool.id, | |||
| "workflow_app_id": db_tool.app_id, | |||
| "icon": json.loads(db_tool.icon), | |||
| "description": db_tool.description, | |||
| "parameters": jsonable_encoder(db_tool.parameter_configurations), | |||
| "tool": ToolTransformService.tool_to_user_tool( | |||
| tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) | |||
| ), | |||
| 'synced': workflow_app.workflow.version == db_tool.version, | |||
| 'privacy_policy': db_tool.privacy_policy, | |||
| "synced": workflow_app.workflow.version == db_tool.version, | |||
| "privacy_policy": db_tool.privacy_policy, | |||
| } | |||
| @classmethod | |||
| def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: | |||
| """ | |||
| @@ -273,40 +283,37 @@ class WorkflowToolManageService: | |||
| :param workflow_app_id: the workflow app id | |||
| :return: the tool | |||
| """ | |||
| db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.app_id == workflow_app_id | |||
| ).first() | |||
| db_tool: WorkflowToolProvider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) | |||
| .first() | |||
| ) | |||
| if db_tool is None: | |||
| raise ValueError(f'Tool {workflow_app_id} not found') | |||
| workflow_app: App = db.session.query(App).filter( | |||
| App.id == db_tool.app_id, | |||
| App.tenant_id == tenant_id | |||
| ).first() | |||
| raise ValueError(f"Tool {workflow_app_id} not found") | |||
| workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() | |||
| if workflow_app is None: | |||
| raise ValueError(f'App {db_tool.app_id} not found') | |||
| raise ValueError(f"App {db_tool.app_id} not found") | |||
| tool = ToolTransformService.workflow_provider_to_controller(db_tool) | |||
| return { | |||
| 'name': db_tool.name, | |||
| 'label': db_tool.label, | |||
| 'workflow_tool_id': db_tool.id, | |||
| 'workflow_app_id': db_tool.app_id, | |||
| 'icon': json.loads(db_tool.icon), | |||
| 'description': db_tool.description, | |||
| 'parameters': jsonable_encoder(db_tool.parameter_configurations), | |||
| 'tool': ToolTransformService.tool_to_user_tool( | |||
| tool.get_tools(user_id, tenant_id)[0], | |||
| labels=ToolLabelManager.get_tool_labels(tool) | |||
| "name": db_tool.name, | |||
| "label": db_tool.label, | |||
| "workflow_tool_id": db_tool.id, | |||
| "workflow_app_id": db_tool.app_id, | |||
| "icon": json.loads(db_tool.icon), | |||
| "description": db_tool.description, | |||
| "parameters": jsonable_encoder(db_tool.parameter_configurations), | |||
| "tool": ToolTransformService.tool_to_user_tool( | |||
| tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) | |||
| ), | |||
| 'synced': workflow_app.workflow.version == db_tool.version, | |||
| 'privacy_policy': db_tool.privacy_policy | |||
| "synced": workflow_app.workflow.version == db_tool.version, | |||
| "privacy_policy": db_tool.privacy_policy, | |||
| } | |||
| @classmethod | |||
| def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: | |||
| """ | |||
| @@ -316,19 +323,19 @@ class WorkflowToolManageService: | |||
| :param workflow_app_id: the workflow app id | |||
| :return: the list of tools | |||
| """ | |||
| db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( | |||
| WorkflowToolProvider.tenant_id == tenant_id, | |||
| WorkflowToolProvider.id == workflow_tool_id | |||
| ).first() | |||
| db_tool: WorkflowToolProvider = ( | |||
| db.session.query(WorkflowToolProvider) | |||
| .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) | |||
| .first() | |||
| ) | |||
| if db_tool is None: | |||
| raise ValueError(f'Tool {workflow_tool_id} not found') | |||
| raise ValueError(f"Tool {workflow_tool_id} not found") | |||
| tool = ToolTransformService.workflow_provider_to_controller(db_tool) | |||
| return [ | |||
| ToolTransformService.tool_to_user_tool( | |||
| tool.get_tools(user_id, tenant_id)[0], | |||
| labels=ToolLabelManager.get_tool_labels(tool) | |||
| tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) | |||
| ) | |||
| ] | |||
| ] | |||
| @@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment | |||
| class VectorService: | |||
| @classmethod | |||
| def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], | |||
| segments: list[DocumentSegment], dataset: Dataset): | |||
| def create_segments_vector( | |||
| cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset | |||
| ): | |||
| documents = [] | |||
| for segment in segments: | |||
| document = Document( | |||
| @@ -20,14 +20,12 @@ class VectorService: | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| }, | |||
| ) | |||
| documents.append(document) | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.indexing_technique == "high_quality": | |||
| # save vector index | |||
| vector = Vector( | |||
| dataset=dataset | |||
| ) | |||
| vector = Vector(dataset=dataset) | |||
| vector.add_texts(documents, duplicate_check=True) | |||
| # save keyword index | |||
| @@ -50,13 +48,11 @@ class VectorService: | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| } | |||
| }, | |||
| ) | |||
| if dataset.indexing_technique == 'high_quality': | |||
| if dataset.indexing_technique == "high_quality": | |||
| # update vector index | |||
| vector = Vector( | |||
| dataset=dataset | |||
| ) | |||
| vector = Vector(dataset=dataset) | |||
| vector.delete_by_ids([segment.index_node_id]) | |||
| vector.add_texts([document], duplicate_check=True) | |||
| @@ -11,18 +11,29 @@ from services.conversation_service import ConversationService | |||
| class WebConversationService: | |||
| @classmethod | |||
| def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], limit: int, invoke_from: InvokeFrom, | |||
| pinned: Optional[bool] = None, | |||
| sort_by='-updated_at') -> InfiniteScrollPagination: | |||
| def pagination_by_last_id( | |||
| cls, | |||
| app_model: App, | |||
| user: Optional[Union[Account, EndUser]], | |||
| last_id: Optional[str], | |||
| limit: int, | |||
| invoke_from: InvokeFrom, | |||
| pinned: Optional[bool] = None, | |||
| sort_by="-updated_at", | |||
| ) -> InfiniteScrollPagination: | |||
| include_ids = None | |||
| exclude_ids = None | |||
| if pinned is not None: | |||
| pinned_conversations = db.session.query(PinnedConversation).filter( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), | |||
| PinnedConversation.created_by == user.id | |||
| ).order_by(PinnedConversation.created_at.desc()).all() | |||
| pinned_conversations = ( | |||
| db.session.query(PinnedConversation) | |||
| .filter( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| PinnedConversation.created_by == user.id, | |||
| ) | |||
| .order_by(PinnedConversation.created_at.desc()) | |||
| .all() | |||
| ) | |||
| pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] | |||
| if pinned: | |||
| include_ids = pinned_conversation_ids | |||
| @@ -37,32 +48,34 @@ class WebConversationService: | |||
| invoke_from=invoke_from, | |||
| include_ids=include_ids, | |||
| exclude_ids=exclude_ids, | |||
| sort_by=sort_by | |||
| sort_by=sort_by, | |||
| ) | |||
| @classmethod | |||
| def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): | |||
| pinned_conversation = db.session.query(PinnedConversation).filter( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.conversation_id == conversation_id, | |||
| PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), | |||
| PinnedConversation.created_by == user.id | |||
| ).first() | |||
| pinned_conversation = ( | |||
| db.session.query(PinnedConversation) | |||
| .filter( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.conversation_id == conversation_id, | |||
| PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| PinnedConversation.created_by == user.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if pinned_conversation: | |||
| return | |||
| conversation = ConversationService.get_conversation( | |||
| app_model=app_model, | |||
| conversation_id=conversation_id, | |||
| user=user | |||
| app_model=app_model, conversation_id=conversation_id, user=user | |||
| ) | |||
| pinned_conversation = PinnedConversation( | |||
| app_id=app_model.id, | |||
| conversation_id=conversation.id, | |||
| created_by_role='account' if isinstance(user, Account) else 'end_user', | |||
| created_by=user.id | |||
| created_by_role="account" if isinstance(user, Account) else "end_user", | |||
| created_by=user.id, | |||
| ) | |||
| db.session.add(pinned_conversation) | |||
| @@ -70,12 +83,16 @@ class WebConversationService: | |||
| @classmethod | |||
| def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): | |||
| pinned_conversation = db.session.query(PinnedConversation).filter( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.conversation_id == conversation_id, | |||
| PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), | |||
| PinnedConversation.created_by == user.id | |||
| ).first() | |||
| pinned_conversation = ( | |||
| db.session.query(PinnedConversation) | |||
| .filter( | |||
| PinnedConversation.app_id == app_model.id, | |||
| PinnedConversation.conversation_id == conversation_id, | |||
| PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), | |||
| PinnedConversation.created_by == user.id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not pinned_conversation: | |||
| return | |||
| @@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService | |||
| class WebsiteService: | |||
| @classmethod | |||
| def document_create_args_validate(cls, args: dict): | |||
| if 'url' not in args or not args['url']: | |||
| raise ValueError('url is required') | |||
| if 'options' not in args or not args['options']: | |||
| raise ValueError('options is required') | |||
| if 'limit' not in args['options'] or not args['options']['limit']: | |||
| raise ValueError('limit is required') | |||
| if "url" not in args or not args["url"]: | |||
| raise ValueError("url is required") | |||
| if "options" not in args or not args["options"]: | |||
| raise ValueError("options is required") | |||
| if "limit" not in args["options"] or not args["options"]["limit"]: | |||
| raise ValueError("limit is required") | |||
| @classmethod | |||
| def crawl_url(cls, args: dict) -> dict: | |||
| provider = args.get('provider') | |||
| url = args.get('url') | |||
| options = args.get('options') | |||
| credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| provider = args.get("provider") | |||
| url = args.get("url") | |||
| options = args.get("options") | |||
| credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) | |||
| if provider == "firecrawl": | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=current_user.current_tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| crawl_sub_pages = options.get('crawl_sub_pages', False) | |||
| only_main_content = options.get('only_main_content', False) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) | |||
| crawl_sub_pages = options.get("crawl_sub_pages", False) | |||
| only_main_content = options.get("only_main_content", False) | |||
| if not crawl_sub_pages: | |||
| params = { | |||
| 'crawlerOptions': { | |||
| "crawlerOptions": { | |||
| "includes": [], | |||
| "excludes": [], | |||
| "generateImgAltText": True, | |||
| "limit": 1, | |||
| 'returnOnlyUrls': False, | |||
| 'pageOptions': { | |||
| 'onlyMainContent': only_main_content, | |||
| "includeHtml": False | |||
| } | |||
| "returnOnlyUrls": False, | |||
| "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, | |||
| } | |||
| } | |||
| else: | |||
| includes = options.get('includes').split(',') if options.get('includes') else [] | |||
| excludes = options.get('excludes').split(',') if options.get('excludes') else [] | |||
| includes = options.get("includes").split(",") if options.get("includes") else [] | |||
| excludes = options.get("excludes").split(",") if options.get("excludes") else [] | |||
| params = { | |||
| 'crawlerOptions': { | |||
| "crawlerOptions": { | |||
| "includes": includes if includes else [], | |||
| "excludes": excludes if excludes else [], | |||
| "generateImgAltText": True, | |||
| "limit": options.get('limit', 1), | |||
| 'returnOnlyUrls': False, | |||
| 'pageOptions': { | |||
| 'onlyMainContent': only_main_content, | |||
| "includeHtml": False | |||
| } | |||
| "limit": options.get("limit", 1), | |||
| "returnOnlyUrls": False, | |||
| "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, | |||
| } | |||
| } | |||
| if options.get('max_depth'): | |||
| params['crawlerOptions']['maxDepth'] = options.get('max_depth') | |||
| if options.get("max_depth"): | |||
| params["crawlerOptions"]["maxDepth"] = options.get("max_depth") | |||
| job_id = firecrawl_app.crawl_url(url, params) | |||
| website_crawl_time_cache_key = f'website_crawl_{job_id}' | |||
| website_crawl_time_cache_key = f"website_crawl_{job_id}" | |||
| time = str(datetime.datetime.now().timestamp()) | |||
| redis_client.setex(website_crawl_time_cache_key, 3600, time) | |||
| return { | |||
| 'status': 'active', | |||
| 'job_id': job_id | |||
| } | |||
| return {"status": "active", "job_id": job_id} | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| raise ValueError("Invalid provider") | |||
| @classmethod | |||
| def get_crawl_status(cls, job_id: str, provider: str) -> dict: | |||
| credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) | |||
| if provider == "firecrawl": | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=current_user.current_tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) | |||
| result = firecrawl_app.check_crawl_status(job_id) | |||
| crawl_status_data = { | |||
| 'status': result.get('status', 'active'), | |||
| 'job_id': job_id, | |||
| 'total': result.get('total', 0), | |||
| 'current': result.get('current', 0), | |||
| 'data': result.get('data', []) | |||
| "status": result.get("status", "active"), | |||
| "job_id": job_id, | |||
| "total": result.get("total", 0), | |||
| "current": result.get("current", 0), | |||
| "data": result.get("data", []), | |||
| } | |||
| if crawl_status_data['status'] == 'completed': | |||
| website_crawl_time_cache_key = f'website_crawl_{job_id}' | |||
| if crawl_status_data["status"] == "completed": | |||
| website_crawl_time_cache_key = f"website_crawl_{job_id}" | |||
| start_time = redis_client.get(website_crawl_time_cache_key) | |||
| if start_time: | |||
| end_time = datetime.datetime.now().timestamp() | |||
| time_consuming = abs(end_time - float(start_time)) | |||
| crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" | |||
| crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" | |||
| redis_client.delete(website_crawl_time_cache_key) | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| raise ValueError("Invalid provider") | |||
| return crawl_status_data | |||
| @classmethod | |||
| def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: | |||
| credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| file_key = 'website_files/' + job_id + '.txt' | |||
| credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) | |||
| if provider == "firecrawl": | |||
| file_key = "website_files/" + job_id + ".txt" | |||
| if storage.exists(file_key): | |||
| data = storage.load_once(file_key) | |||
| if data: | |||
| data = json.loads(data.decode('utf-8')) | |||
| data = json.loads(data.decode("utf-8")) | |||
| else: | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) | |||
| result = firecrawl_app.check_crawl_status(job_id) | |||
| if result.get('status') != 'completed': | |||
| raise ValueError('Crawl job is not completed') | |||
| data = result.get('data') | |||
| if result.get("status") != "completed": | |||
| raise ValueError("Crawl job is not completed") | |||
| data = result.get("data") | |||
| if data: | |||
| for item in data: | |||
| if item.get('source_url') == url: | |||
| if item.get("source_url") == url: | |||
| return item | |||
| return None | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| raise ValueError("Invalid provider") | |||
| @classmethod | |||
| def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: | |||
| credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) | |||
| if provider == "firecrawl": | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| params = { | |||
| 'pageOptions': { | |||
| 'onlyMainContent': only_main_content, | |||
| "includeHtml": False | |||
| } | |||
| } | |||
| api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) | |||
| params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}} | |||
| result = firecrawl_app.scrape_url(url, params) | |||
| return result | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| raise ValueError("Invalid provider") | |||
| @@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus | |||
| class WorkflowAppService: | |||
| def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: | |||
| """ | |||
| Get paginate workflow app logs | |||
| @@ -18,20 +17,14 @@ class WorkflowAppService: | |||
| :param args: request args | |||
| :return: | |||
| """ | |||
| query = ( | |||
| db.select(WorkflowAppLog) | |||
| .where( | |||
| WorkflowAppLog.tenant_id == app_model.tenant_id, | |||
| WorkflowAppLog.app_id == app_model.id | |||
| ) | |||
| query = db.select(WorkflowAppLog).where( | |||
| WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id | |||
| ) | |||
| status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None | |||
| keyword = args['keyword'] | |||
| status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None | |||
| keyword = args["keyword"] | |||
| if keyword or status: | |||
| query = query.join( | |||
| WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id | |||
| ) | |||
| query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) | |||
| if keyword: | |||
| keyword_like_val = f"%{args['keyword'][:30]}%" | |||
| @@ -39,7 +32,7 @@ class WorkflowAppService: | |||
| WorkflowRun.inputs.ilike(keyword_like_val), | |||
| WorkflowRun.outputs.ilike(keyword_like_val), | |||
| # filter keyword by end user session id if created by end user role | |||
| and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val)) | |||
| and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), | |||
| ] | |||
| # filter keyword by workflow run id | |||
| @@ -49,23 +42,16 @@ class WorkflowAppService: | |||
| query = query.outerjoin( | |||
| EndUser, | |||
| and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) | |||
| and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), | |||
| ).filter(or_(*keyword_conditions)) | |||
| if status: | |||
| # join with workflow_run and filter by status | |||
| query = query.filter( | |||
| WorkflowRun.status == status.value | |||
| ) | |||
| query = query.filter(WorkflowRun.status == status.value) | |||
| query = query.order_by(WorkflowAppLog.created_at.desc()) | |||
| pagination = db.paginate( | |||
| query, | |||
| page=args['page'], | |||
| per_page=args['limit'], | |||
| error_out=False | |||
| ) | |||
| pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) | |||
| return pagination | |||
| @@ -18,6 +18,7 @@ class WorkflowRunService: | |||
| :param app_model: app model | |||
| :param args: request args | |||
| """ | |||
| class WorkflowWithMessage: | |||
| message_id: str | |||
| conversation_id: str | |||
| @@ -33,9 +34,7 @@ class WorkflowRunService: | |||
| with_message_workflow_runs = [] | |||
| for workflow_run in pagination.data: | |||
| message = workflow_run.message | |||
| with_message_workflow_run = WorkflowWithMessage( | |||
| workflow_run=workflow_run | |||
| ) | |||
| with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run) | |||
| if message: | |||
| with_message_workflow_run.message_id = message.id | |||
| with_message_workflow_run.conversation_id = message.conversation_id | |||
| @@ -53,26 +52,30 @@ class WorkflowRunService: | |||
| :param app_model: app model | |||
| :param args: request args | |||
| """ | |||
| limit = int(args.get('limit', 20)) | |||
| limit = int(args.get("limit", 20)) | |||
| base_query = db.session.query(WorkflowRun).filter( | |||
| WorkflowRun.tenant_id == app_model.tenant_id, | |||
| WorkflowRun.app_id == app_model.id, | |||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value | |||
| WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, | |||
| ) | |||
| if args.get('last_id'): | |||
| if args.get("last_id"): | |||
| last_workflow_run = base_query.filter( | |||
| WorkflowRun.id == args.get('last_id'), | |||
| WorkflowRun.id == args.get("last_id"), | |||
| ).first() | |||
| if not last_workflow_run: | |||
| raise ValueError('Last workflow run not exists') | |||
| workflow_runs = base_query.filter( | |||
| WorkflowRun.created_at < last_workflow_run.created_at, | |||
| WorkflowRun.id != last_workflow_run.id | |||
| ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() | |||
| raise ValueError("Last workflow run not exists") | |||
| workflow_runs = ( | |||
| base_query.filter( | |||
| WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id | |||
| ) | |||
| .order_by(WorkflowRun.created_at.desc()) | |||
| .limit(limit) | |||
| .all() | |||
| ) | |||
| else: | |||
| workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() | |||
| @@ -81,17 +84,13 @@ class WorkflowRunService: | |||
| current_page_first_workflow_run = workflow_runs[-1] | |||
| rest_count = base_query.filter( | |||
| WorkflowRun.created_at < current_page_first_workflow_run.created_at, | |||
| WorkflowRun.id != current_page_first_workflow_run.id | |||
| WorkflowRun.id != current_page_first_workflow_run.id, | |||
| ).count() | |||
| if rest_count > 0: | |||
| has_more = True | |||
| return InfiniteScrollPagination( | |||
| data=workflow_runs, | |||
| limit=limit, | |||
| has_more=has_more | |||
| ) | |||
| return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) | |||
| def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: | |||
| """ | |||
| @@ -100,11 +99,15 @@ class WorkflowRunService: | |||
| :param app_model: app model | |||
| :param run_id: workflow run id | |||
| """ | |||
| workflow_run = db.session.query(WorkflowRun).filter( | |||
| WorkflowRun.tenant_id == app_model.tenant_id, | |||
| WorkflowRun.app_id == app_model.id, | |||
| WorkflowRun.id == run_id, | |||
| ).first() | |||
| workflow_run = ( | |||
| db.session.query(WorkflowRun) | |||
| .filter( | |||
| WorkflowRun.tenant_id == app_model.tenant_id, | |||
| WorkflowRun.app_id == app_model.id, | |||
| WorkflowRun.id == run_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| return workflow_run | |||
| @@ -117,12 +120,17 @@ class WorkflowRunService: | |||
| if not workflow_run: | |||
| return [] | |||
| node_executions = db.session.query(WorkflowNodeExecution).filter( | |||
| WorkflowNodeExecution.tenant_id == app_model.tenant_id, | |||
| WorkflowNodeExecution.app_id == app_model.id, | |||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| WorkflowNodeExecution.workflow_run_id == run_id, | |||
| ).order_by(WorkflowNodeExecution.index.desc()).all() | |||
| node_executions = ( | |||
| db.session.query(WorkflowNodeExecution) | |||
| .filter( | |||
| WorkflowNodeExecution.tenant_id == app_model.tenant_id, | |||
| WorkflowNodeExecution.app_id == app_model.id, | |||
| WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, | |||
| WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, | |||
| WorkflowNodeExecution.workflow_run_id == run_id, | |||
| ) | |||
| .order_by(WorkflowNodeExecution.index.desc()) | |||
| .all() | |||
| ) | |||
| return node_executions | |||
| @@ -37,11 +37,13 @@ class WorkflowService: | |||
| Get draft workflow | |||
| """ | |||
| # fetch draft workflow by app_model | |||
| workflow = db.session.query(Workflow).filter( | |||
| Workflow.tenant_id == app_model.tenant_id, | |||
| Workflow.app_id == app_model.id, | |||
| Workflow.version == 'draft' | |||
| ).first() | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" | |||
| ) | |||
| .first() | |||
| ) | |||
| # return draft workflow | |||
| return workflow | |||
| @@ -55,11 +57,15 @@ class WorkflowService: | |||
| return None | |||
| # fetch published workflow by workflow_id | |||
| workflow = db.session.query(Workflow).filter( | |||
| Workflow.tenant_id == app_model.tenant_id, | |||
| Workflow.app_id == app_model.id, | |||
| Workflow.id == app_model.workflow_id | |||
| ).first() | |||
| workflow = ( | |||
| db.session.query(Workflow) | |||
| .filter( | |||
| Workflow.tenant_id == app_model.tenant_id, | |||
| Workflow.app_id == app_model.id, | |||
| Workflow.id == app_model.workflow_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| return workflow | |||
| @@ -85,10 +91,7 @@ class WorkflowService: | |||
| raise WorkflowHashNotEqualError() | |||
| # validate features structure | |||
| self.validate_features_structure( | |||
| app_model=app_model, | |||
| features=features | |||
| ) | |||
| self.validate_features_structure(app_model=app_model, features=features) | |||
| # create draft workflow if not found | |||
| if not workflow: | |||
| @@ -96,7 +99,7 @@ class WorkflowService: | |||
| tenant_id=app_model.tenant_id, | |||
| app_id=app_model.id, | |||
| type=WorkflowType.from_app_mode(app_model.mode).value, | |||
| version='draft', | |||
| version="draft", | |||
| graph=json.dumps(graph), | |||
| features=json.dumps(features), | |||
| created_by=account.id, | |||
| @@ -122,9 +125,7 @@ class WorkflowService: | |||
| # return draft workflow | |||
| return workflow | |||
| def publish_workflow(self, app_model: App, | |||
| account: Account, | |||
| draft_workflow: Optional[Workflow] = None) -> Workflow: | |||
| def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: | |||
| """ | |||
| Publish workflow from draft | |||
| @@ -137,7 +138,7 @@ class WorkflowService: | |||
| draft_workflow = self.get_draft_workflow(app_model=app_model) | |||
| if not draft_workflow: | |||
| raise ValueError('No valid workflow found.') | |||
| raise ValueError("No valid workflow found.") | |||
| # create new workflow | |||
| workflow = Workflow( | |||
| @@ -187,17 +188,16 @@ class WorkflowService: | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| return workflow_engine_manager.get_default_config(node_type, filters) | |||
| def run_draft_workflow_node(self, app_model: App, | |||
| node_id: str, | |||
| user_inputs: dict, | |||
| account: Account) -> WorkflowNodeExecution: | |||
| def run_draft_workflow_node( | |||
| self, app_model: App, node_id: str, user_inputs: dict, account: Account | |||
| ) -> WorkflowNodeExecution: | |||
| """ | |||
| Run draft workflow node | |||
| """ | |||
| # fetch draft workflow by app_model | |||
| draft_workflow = self.get_draft_workflow(app_model=app_model) | |||
| if not draft_workflow: | |||
| raise ValueError('Workflow not initialized') | |||
| raise ValueError("Workflow not initialized") | |||
| # run draft workflow node | |||
| workflow_engine_manager = WorkflowEngineManager() | |||
| @@ -226,7 +226,7 @@ class WorkflowService: | |||
| created_by_role=CreatedByRole.ACCOUNT.value, | |||
| created_by=account.id, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None) | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| db.session.add(workflow_node_execution) | |||
| db.session.commit() | |||
| @@ -247,14 +247,15 @@ class WorkflowService: | |||
| inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, | |||
| process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, | |||
| outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, | |||
| execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) | |||
| if node_run_result.metadata else None), | |||
| execution_metadata=( | |||
| json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None | |||
| ), | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED.value, | |||
| elapsed_time=time.perf_counter() - start_at, | |||
| created_by_role=CreatedByRole.ACCOUNT.value, | |||
| created_by=account.id, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None) | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| else: | |||
| # create workflow node execution | |||
| @@ -273,7 +274,7 @@ class WorkflowService: | |||
| created_by_role=CreatedByRole.ACCOUNT.value, | |||
| created_by=account.id, | |||
| created_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None) | |||
| finished_at=datetime.now(timezone.utc).replace(tzinfo=None), | |||
| ) | |||
| db.session.add(workflow_node_execution) | |||
| @@ -295,16 +296,16 @@ class WorkflowService: | |||
| workflow_converter = WorkflowConverter() | |||
| if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: | |||
| raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') | |||
| raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") | |||
| # convert to workflow | |||
| new_app = workflow_converter.convert_to_workflow( | |||
| app_model=app_model, | |||
| account=account, | |||
| name=args.get('name'), | |||
| icon_type=args.get('icon_type'), | |||
| icon=args.get('icon'), | |||
| icon_background=args.get('icon_background'), | |||
| name=args.get("name"), | |||
| icon_type=args.get("icon_type"), | |||
| icon=args.get("icon"), | |||
| icon_background=args.get("icon_background"), | |||
| ) | |||
| return new_app | |||
| @@ -312,15 +313,11 @@ class WorkflowService: | |||
| def validate_features_structure(self, app_model: App, features: dict) -> dict: | |||
| if app_model.mode == AppMode.ADVANCED_CHAT.value: | |||
| return AdvancedChatAppConfigManager.config_validate( | |||
| tenant_id=app_model.tenant_id, | |||
| config=features, | |||
| only_structure_validate=True | |||
| tenant_id=app_model.tenant_id, config=features, only_structure_validate=True | |||
| ) | |||
| elif app_model.mode == AppMode.WORKFLOW.value: | |||
| return WorkflowAppConfigManager.config_validate( | |||
| tenant_id=app_model.tenant_id, | |||
| config=features, | |||
| only_structure_validate=True | |||
| tenant_id=app_model.tenant_id, config=features, only_structure_validate=True | |||
| ) | |||
| else: | |||
| raise ValueError(f"Invalid app mode: {app_model.mode}") | |||
| @@ -1,4 +1,3 @@ | |||
| from flask_login import current_user | |||
| from configs import dify_config | |||
| @@ -14,34 +13,40 @@ class WorkspaceService: | |||
| if not tenant: | |||
| return None | |||
| tenant_info = { | |||
| 'id': tenant.id, | |||
| 'name': tenant.name, | |||
| 'plan': tenant.plan, | |||
| 'status': tenant.status, | |||
| 'created_at': tenant.created_at, | |||
| 'in_trail': True, | |||
| 'trial_end_reason': None, | |||
| 'role': 'normal', | |||
| "id": tenant.id, | |||
| "name": tenant.name, | |||
| "plan": tenant.plan, | |||
| "status": tenant.status, | |||
| "created_at": tenant.created_at, | |||
| "in_trail": True, | |||
| "trial_end_reason": None, | |||
| "role": "normal", | |||
| } | |||
| # Get role of user | |||
| tenant_account_join = db.session.query(TenantAccountJoin).filter( | |||
| TenantAccountJoin.tenant_id == tenant.id, | |||
| TenantAccountJoin.account_id == current_user.id | |||
| ).first() | |||
| tenant_info['role'] = tenant_account_join.role | |||
| can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo | |||
| if can_replace_logo and TenantService.has_roles(tenant, | |||
| [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): | |||
| tenant_account_join = ( | |||
| db.session.query(TenantAccountJoin) | |||
| .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) | |||
| .first() | |||
| ) | |||
| tenant_info["role"] = tenant_account_join.role | |||
| can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo | |||
| if can_replace_logo and TenantService.has_roles( | |||
| tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] | |||
| ): | |||
| base_url = dify_config.FILES_URL | |||
| replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None | |||
| remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) | |||
| tenant_info['custom_config'] = { | |||
| 'remove_webapp_brand': remove_webapp_brand, | |||
| 'replace_webapp_logo': replace_webapp_logo, | |||
| replace_webapp_logo = ( | |||
| f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" | |||
| if tenant.custom_config_dict.get("replace_webapp_logo") | |||
| else None | |||
| ) | |||
| remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) | |||
| tenant_info["custom_config"] = { | |||
| "remove_webapp_brand": remove_webapp_brand, | |||
| "replace_webapp_logo": replace_webapp_logo, | |||
| } | |||
| return tenant_info | |||