| @@ -1,34 +0,0 @@ | |||
| name: Setup UV and Python | |||
| inputs: | |||
| python-version: | |||
| description: Python version to use and the UV installed with | |||
| required: true | |||
| default: '3.12' | |||
| uv-version: | |||
| description: UV version to set up | |||
| required: true | |||
| default: '0.8.9' | |||
| uv-lockfile: | |||
| description: Path to the UV lockfile to restore cache from | |||
| required: true | |||
| default: '' | |||
| enable-cache: | |||
| required: true | |||
| default: true | |||
| runs: | |||
| using: composite | |||
| steps: | |||
| - name: Set up Python ${{ inputs.python-version }} | |||
| uses: actions/setup-python@v5 | |||
| with: | |||
| python-version: ${{ inputs.python-version }} | |||
| - name: Install uv | |||
| uses: astral-sh/setup-uv@v6 | |||
| with: | |||
| version: ${{ inputs.uv-version }} | |||
| python-version: ${{ inputs.python-version }} | |||
| enable-cache: ${{ inputs.enable-cache }} | |||
| cache-dependency-glob: ${{ inputs.uv-lockfile }} | |||
| @@ -33,10 +33,11 @@ jobs: | |||
| persist-credentials: false | |||
| - name: Setup UV and Python | |||
| uses: ./.github/actions/setup-uv | |||
| uses: astral-sh/setup-uv@v6 | |||
| with: | |||
| enable-cache: true | |||
| python-version: ${{ matrix.python-version }} | |||
| uv-lockfile: api/uv.lock | |||
| cache-dependency-glob: api/uv.lock | |||
| - name: Check UV lockfile | |||
| run: uv lock --project api --check | |||
| @@ -2,6 +2,7 @@ name: autofix.ci | |||
| on: | |||
| workflow_call: | |||
| pull_request: | |||
| branches: [ "main" ] | |||
| push: | |||
| branches: [ "main" ] | |||
| permissions: | |||
| @@ -15,7 +16,9 @@ jobs: | |||
| - uses: actions/checkout@v4 | |||
| # Use uv to ensure we have the same ruff version in CI and locally. | |||
| - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f | |||
| - uses: astral-sh/setup-uv@v6 | |||
| with: | |||
| python-version: "3.12" | |||
| - run: | | |||
| cd api | |||
| uv sync --dev | |||
| @@ -25,9 +25,11 @@ jobs: | |||
| persist-credentials: false | |||
| - name: Setup UV and Python | |||
| uses: ./.github/actions/setup-uv | |||
| uses: astral-sh/setup-uv@v6 | |||
| with: | |||
| uv-lockfile: api/uv.lock | |||
| enable-cache: true | |||
| python-version: "3.12" | |||
| cache-dependency-glob: api/uv.lock | |||
| - name: Install dependencies | |||
| run: uv sync --project api | |||
| @@ -36,10 +36,11 @@ jobs: | |||
| - name: Setup UV and Python | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| uses: ./.github/actions/setup-uv | |||
| uses: astral-sh/setup-uv@v6 | |||
| with: | |||
| uv-lockfile: api/uv.lock | |||
| enable-cache: false | |||
| python-version: "3.12" | |||
| cache-dependency-glob: api/uv.lock | |||
| - name: Install dependencies | |||
| if: steps.changed-files.outputs.any_changed == 'true' | |||
| @@ -39,10 +39,11 @@ jobs: | |||
| remove_tool_cache: true | |||
| - name: Setup UV and Python | |||
| uses: ./.github/actions/setup-uv | |||
| uses: astral-sh/setup-uv@v6 | |||
| with: | |||
| enable-cache: true | |||
| python-version: ${{ matrix.python-version }} | |||
| uv-lockfile: api/uv.lock | |||
| cache-dependency-glob: api/uv.lock | |||
| - name: Check UV lockfile | |||
| run: uv lock --project api --check | |||
| @@ -0,0 +1 @@ | |||
| CLAUDE.md | |||
| @@ -180,7 +180,7 @@ docker compose up -d | |||
| ## Contributing | |||
| 对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 | |||
| 对于那些想要贡献代码的人,请参阅我们的[贡献指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_CN.md)。 | |||
| 同时,请考虑通过社交媒体、活动和会议来支持 Dify 的分享。 | |||
| > 我们正在寻找贡献者来帮助将 Dify 翻译成除了中文和英文之外的其他语言。如果您有兴趣帮助,请参阅我们的[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)获取更多信息,并在我们的[Discord 社区服务器](https://discord.gg/8Tpq4AcN9c)的`global-users`频道中留言。 | |||
| @@ -173,7 +173,7 @@ Stellen Sie Dify mit einem Klick in AKS bereit, indem Sie [Azure Devops Pipeline | |||
| ## Contributing | |||
| Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. | |||
| Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_DE.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. | |||
| > Wir suchen Mitwirkende, die dabei helfen, Dify in weitere Sprachen zu übersetzen – außer Mandarin oder Englisch. Wenn Sie Interesse an einer Mitarbeit haben, lesen Sie bitte die [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) für weitere Informationen und hinterlassen Sie einen Kommentar im `global-users`-Kanal unseres [Discord Community Servers](https://discord.gg/8Tpq4AcN9c). | |||
| @@ -170,7 +170,7 @@ Implementa Dify en AKS con un clic usando [Azure Devops Pipeline Helm Chart by @ | |||
| ## Contribuir | |||
| Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). | |||
| Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_ES.md). | |||
| Al mismo tiempo, considera apoyar a Dify compartiéndolo en redes sociales y en eventos y conferencias. | |||
| > Estamos buscando colaboradores para ayudar con la traducción de Dify a idiomas que no sean el mandarín o el inglés. Si estás interesado en ayudar, consulta el [README de i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para obtener más información y déjanos un comentario en el canal `global-users` de nuestro [Servidor de Comunidad en Discord](https://discord.gg/8Tpq4AcN9c). | |||
| @@ -168,7 +168,7 @@ Déployez Dify sur AKS en un clic en utilisant [Azure Devops Pipeline Helm Chart | |||
| ## Contribuer | |||
| Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). | |||
| Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_FR.md). | |||
| Dans le même temps, veuillez envisager de soutenir Dify en le partageant sur les réseaux sociaux et lors d'événements et de conférences. | |||
| > Nous recherchons des contributeurs pour aider à traduire Dify dans des langues autres que le mandarin ou l'anglais. Si vous êtes intéressé à aider, veuillez consulter le [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) pour plus d'informations, et laissez-nous un commentaire dans le canal `global-users` de notre [Serveur communautaire Discord](https://discord.gg/8Tpq4AcN9c). | |||
| @@ -169,7 +169,7 @@ docker compose up -d | |||
| ## 貢献 | |||
| コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 | |||
| コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_JA.md)を参照してください。 | |||
| 同時に、DifyをSNSやイベント、カンファレンスで共有してサポートしていただけると幸いです。 | |||
| > Difyを英語または中国語以外の言語に翻訳してくれる貢献者を募集しています。興味がある場合は、詳細については[i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)を参照してください。また、[Discordコミュニティサーバー](https://discord.gg/8Tpq4AcN9c)の`global-users`チャンネルにコメントを残してください。 | |||
| @@ -162,7 +162,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 | |||
| ## 기여 | |||
| 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. | |||
| 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_KR.md)를 참조하세요. | |||
| 동시에 Dify를 소셜 미디어와 행사 및 컨퍼런스에 공유하여 지원하는 것을 고려해 주시기 바랍니다. | |||
| > 우리는 Dify를 중국어나 영어 이외의 언어로 번역하는 데 도움을 줄 수 있는 기여자를 찾고 있습니다. 도움을 주고 싶으시다면 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md)에서 더 많은 정보를 확인하시고 [Discord 커뮤니티 서버](https://discord.gg/8Tpq4AcN9c)의 `global-users` 채널에 댓글을 남겨주세요. | |||
| @@ -168,7 +168,7 @@ Implante o Dify no AKS com um clique usando [Azure Devops Pipeline Helm Chart by | |||
| ## Contribuindo | |||
| Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). | |||
| Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_PT.md). | |||
| Ao mesmo tempo, considere apoiar o Dify compartilhando-o nas redes sociais e em eventos e conferências. | |||
| > Estamos buscando contribuidores para ajudar na tradução do Dify para idiomas além de Mandarim e Inglês. Se você tiver interesse em ajudar, consulte o [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) para mais informações e deixe-nos um comentário no canal `global-users` em nosso [Servidor da Comunidade no Discord](https://discord.gg/8Tpq4AcN9c). | |||
| @@ -161,7 +161,7 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter | |||
| ## Katkıda Bulunma | |||
| Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. | |||
| Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TR.md) bakabilirsiniz. | |||
| Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda paylaşarak desteklemeyi düşünün. | |||
| > Dify'ı Mandarin veya İngilizce dışındaki dillere çevirmemize yardımcı olacak katkıda bulunanlara ihtiyacımız var. Yardımcı olmakla ilgileniyorsanız, lütfen daha fazla bilgi için [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) dosyasına bakın ve [Discord Topluluk Sunucumuzdaki](https://discord.gg/8Tpq4AcN9c) `global-users` kanalında bize bir yorum bırakın. | |||
| @@ -173,7 +173,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify | |||
| ## 貢獻 | |||
| 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 | |||
| 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_TW.md)。 | |||
| 同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。 | |||
| > 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。 | |||
| @@ -162,7 +162,7 @@ Triển khai Dify lên AKS chỉ với một cú nhấp chuột bằng [Azure De | |||
| ## Đóng góp | |||
| Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. | |||
| Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING_VI.md) của chúng tôi. | |||
| Đồng thời, vui lòng xem xét hỗ trợ Dify bằng cách chia sẻ nó trên mạng xã hội và tại các sự kiện và hội nghị. | |||
| > Chúng tôi đang tìm kiếm người đóng góp để giúp dịch Dify sang các ngôn ngữ khác ngoài tiếng Trung hoặc tiếng Anh. Nếu bạn quan tâm đến việc giúp đỡ, vui lòng xem [README i18n](https://github.com/langgenius/dify/blob/main/web/i18n-config/README.md) để biết thêm thông tin và để lại bình luận cho chúng tôi trong kênh `global-users` của [Máy chủ Cộng đồng Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi. | |||
| @@ -564,3 +564,7 @@ QUEUE_MONITOR_THRESHOLD=200 | |||
| QUEUE_MONITOR_ALERT_EMAILS= | |||
| # Monitor interval in minutes, default is 30 minutes | |||
| QUEUE_MONITOR_INTERVAL=30 | |||
| # Swagger UI configuration | |||
| SWAGGER_UI_ENABLED=true | |||
| SWAGGER_UI_PATH=/swagger-ui.html | |||
| @@ -43,6 +43,7 @@ select = [ | |||
| "S302", # suspicious-marshal-usage, disallow use of `marshal` module | |||
| "S311", # suspicious-non-cryptographic-random-usage | |||
| "G001", # don't use str format to logging messages | |||
| "G003", # don't use + in logging messages | |||
| "G004", # don't use f-strings to format logging messages | |||
| ] | |||
| @@ -99,14 +99,14 @@ uv run celery -A app.celery beat | |||
| 1. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`, more can check [Claude.md](../CLAUDE.md) | |||
| ```cli | |||
| uv run --project api pytest # Run all tests | |||
| uv run --project api pytest tests/unit_tests/ # Unit tests only | |||
| uv run --project api pytest tests/integration_tests/ # Integration tests | |||
| ```bash | |||
| uv run pytest # Run all tests | |||
| uv run pytest tests/unit_tests/ # Unit tests only | |||
| uv run pytest tests/integration_tests/ # Integration tests | |||
| # Code quality | |||
| ./dev/reformat # Run all formatters and linters | |||
| uv run --project api ruff check --fix ./ # Fix linting issues | |||
| uv run --project api ruff format ./ # Format code | |||
| uv run --project api mypy . # Type checking | |||
| ../dev/reformat # Run all formatters and linters | |||
| uv run ruff check --fix ./ # Fix linting issues | |||
| uv run ruff format ./ # Format code | |||
| uv run mypy . # Type checking | |||
| ``` | |||
| @@ -0,0 +1,11 @@ | |||
| from tests.integration_tests.utils.parent_class import ParentClass | |||
| class ChildClass(ParentClass): | |||
| """Test child class for module import helper tests""" | |||
| def __init__(self, name): | |||
| super().__init__(name) | |||
| def get_name(self): | |||
| return f"Child: {self.name}" | |||
| @@ -43,6 +43,8 @@ from services.plugin.data_migration import PluginDataMigration | |||
| from services.plugin.plugin_migration import PluginMigration | |||
| from tasks.remove_app_and_related_data_task import delete_draft_variables_batch | |||
| logger = logging.getLogger(__name__) | |||
| @click.command("reset-password", help="Reset the account password.") | |||
| @click.option("--email", prompt=True, help="Account email to reset password for") | |||
| @@ -690,7 +692,7 @@ def upgrade_db(): | |||
| click.echo(click.style("Database migration successful!", fg="green")) | |||
| except Exception: | |||
| logging.exception("Failed to execute database migration") | |||
| logger.exception("Failed to execute database migration") | |||
| finally: | |||
| lock.release() | |||
| else: | |||
| @@ -738,7 +740,7 @@ where sites.id is null limit 1000""" | |||
| except Exception: | |||
| failed_app_ids.append(app_id) | |||
| click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red")) | |||
| logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id) | |||
| logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id) | |||
| continue | |||
| if not processed_count: | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import Annotated, Literal, Optional | |||
| from typing import Literal, Optional | |||
| from pydantic import ( | |||
| AliasChoices, | |||
| @@ -976,6 +976,18 @@ class WorkflowLogConfig(BaseSettings): | |||
| ) | |||
| class SwaggerUIConfig(BaseSettings): | |||
| SWAGGER_UI_ENABLED: bool = Field( | |||
| description="Whether to enable Swagger UI in api module", | |||
| default=True, | |||
| ) | |||
| SWAGGER_UI_PATH: str = Field( | |||
| description="Swagger UI page path in api module", | |||
| default="/swagger-ui.html", | |||
| ) | |||
| class FeatureConfig( | |||
| # place the configs in alphabet order | |||
| AppExecutionConfig, | |||
| @@ -1007,6 +1019,7 @@ class FeatureConfig( | |||
| WorkspaceConfig, | |||
| LoginConfig, | |||
| AccountConfig, | |||
| SwaggerUIConfig, | |||
| # hosted services config | |||
| HostedServiceConfig, | |||
| CeleryBeatConfig, | |||
| @@ -215,6 +215,7 @@ class DatabaseConfig(BaseSettings): | |||
| "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, | |||
| "connect_args": connect_args, | |||
| "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO, | |||
| "pool_reset_on_return": None, | |||
| } | |||
| @@ -31,6 +31,8 @@ from services.errors.audio import ( | |||
| UnsupportedAudioTypeServiceError, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| class ChatMessageAudioApi(Resource): | |||
| @setup_required | |||
| @@ -49,7 +51,7 @@ class ChatMessageAudioApi(Resource): | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -70,7 +72,7 @@ class ChatMessageAudioApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("Failed to handle post request to ChatMessageAudioApi") | |||
| logger.exception("Failed to handle post request to ChatMessageAudioApi") | |||
| raise InternalServerError() | |||
| @@ -97,7 +99,7 @@ class ChatMessageTextApi(Resource): | |||
| ) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -118,7 +120,7 @@ class ChatMessageTextApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("Failed to handle post request to ChatMessageTextApi") | |||
| logger.exception("Failed to handle post request to ChatMessageTextApi") | |||
| raise InternalServerError() | |||
| @@ -160,7 +162,7 @@ class TextModesApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("Failed to handle get request to TextModesApi") | |||
| logger.exception("Failed to handle get request to TextModesApi") | |||
| raise InternalServerError() | |||
| @@ -34,6 +34,8 @@ from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| logger = logging.getLogger(__name__) | |||
| # define completion message api for user | |||
| class CompletionMessageApi(Resource): | |||
| @@ -67,7 +69,7 @@ class CompletionMessageApi(Resource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -80,7 +82,7 @@ class CompletionMessageApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -134,7 +136,7 @@ class ChatMessageApi(Resource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -149,7 +151,7 @@ class ChatMessageApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -33,6 +33,8 @@ from services.errors.conversation import ConversationNotExistsError | |||
| from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError | |||
| from services.message_service import MessageService | |||
| logger = logging.getLogger(__name__) | |||
| class ChatMessageListApi(Resource): | |||
| message_infinite_scroll_pagination_fields = { | |||
| @@ -215,7 +217,7 @@ class MessageSuggestedQuestionApi(Resource): | |||
| except SuggestedQuestionsAfterAnswerDisabledError: | |||
| raise AppSuggestedQuestionsAfterAnswerDisabledError() | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {"data": questions} | |||
| @@ -72,6 +72,7 @@ class DraftWorkflowApi(Resource): | |||
| Get draft workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -94,6 +95,7 @@ class DraftWorkflowApi(Resource): | |||
| Sync draft workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -171,6 +173,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): | |||
| Run draft workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -205,7 +208,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -218,13 +221,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| """ | |||
| Run draft workflow iteration node | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| args = parser.parse_args() | |||
| @@ -242,7 +244,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -256,11 +258,10 @@ class WorkflowDraftRunIterationNodeApi(Resource): | |||
| Run draft workflow iteration node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| @@ -279,7 +280,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -292,12 +293,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): | |||
| """ | |||
| Run draft workflow loop node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| @@ -316,7 +317,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -329,12 +330,12 @@ class WorkflowDraftRunLoopNodeApi(Resource): | |||
| """ | |||
| Run draft workflow loop node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, location="json") | |||
| @@ -353,7 +354,7 @@ class WorkflowDraftRunLoopNodeApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -366,12 +367,12 @@ class DraftWorkflowRunApi(Resource): | |||
| """ | |||
| Run draft workflow | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| @@ -405,6 +406,9 @@ class WorkflowTaskStopApi(Resource): | |||
| """ | |||
| Stop workflow task | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -424,12 +428,12 @@ class DraftWorkflowNodeRunApi(Resource): | |||
| """ | |||
| Run draft workflow node | |||
| """ | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| @@ -472,6 +476,9 @@ class PublishedWorkflowApi(Resource): | |||
| """ | |||
| Get published workflow | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -491,13 +498,12 @@ class PublishedWorkflowApi(Resource): | |||
| """ | |||
| Publish workflow | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("marked_name", type=str, required=False, default="", location="json") | |||
| parser.add_argument("marked_comment", type=str, required=False, default="", location="json") | |||
| @@ -541,6 +547,9 @@ class DefaultBlockConfigsApi(Resource): | |||
| """ | |||
| Get default block config | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -559,13 +568,12 @@ class DefaultBlockConfigApi(Resource): | |||
| """ | |||
| Get default block config | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("q", type=str, location="args") | |||
| args = parser.parse_args() | |||
| @@ -595,13 +603,12 @@ class ConvertToWorkflowApi(Resource): | |||
| Convert expert mode of chatbot app to workflow mode | |||
| Convert Completion App to Workflow App | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| if request.data: | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=False, nullable=True, location="json") | |||
| @@ -645,6 +652,9 @@ class PublishedAllWorkflowApi(Resource): | |||
| """ | |||
| Get published workflows | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -693,13 +703,12 @@ class WorkflowByIdApi(Resource): | |||
| """ | |||
| Update workflow attributes | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # Check permission | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("marked_name", type=str, required=False, location="json") | |||
| parser.add_argument("marked_comment", type=str, required=False, location="json") | |||
| @@ -750,13 +759,12 @@ class WorkflowByIdApi(Resource): | |||
| """ | |||
| Delete workflow | |||
| """ | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| # Check permission | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| if not isinstance(current_user, Account): | |||
| raise Forbidden() | |||
| workflow_service = WorkflowService() | |||
| # Create a session and manage the transaction | |||
| @@ -21,6 +21,7 @@ from factories.file_factory import build_from_mapping, build_from_mappings | |||
| from factories.variable_factory import build_segment_with_type | |||
| from libs.login import current_user, login_required | |||
| from models import App, AppMode, db | |||
| from models.account import Account | |||
| from models.workflow import WorkflowDraftVariable | |||
| from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService | |||
| from services.workflow_service import WorkflowService | |||
| @@ -135,6 +136,7 @@ def _api_prerequisite(f): | |||
| @account_initialization_required | |||
| @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) | |||
| def wrapper(*args, **kwargs): | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| return f(*args, **kwargs) | |||
| @@ -6,9 +6,11 @@ from controllers.console.app.error import AppNotFoundError | |||
| from extensions.ext_database import db | |||
| from libs.login import current_user | |||
| from models import App, AppMode | |||
| from models.account import Account | |||
| def _load_app_model(app_id: str) -> Optional[App]: | |||
| assert isinstance(current_user, Account) | |||
| app_model = ( | |||
| db.session.query(App) | |||
| .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") | |||
| @@ -13,6 +13,8 @@ from libs.oauth_data_source import NotionOAuth | |||
| from ..wraps import account_initialization_required, setup_required | |||
| logger = logging.getLogger(__name__) | |||
| def get_oauth_providers(): | |||
| with current_app.app_context(): | |||
| @@ -80,7 +82,7 @@ class OAuthDataSourceBinding(Resource): | |||
| try: | |||
| oauth_provider.get_access_token(code) | |||
| except requests.exceptions.HTTPError as e: | |||
| logging.exception( | |||
| logger.exception( | |||
| "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text | |||
| ) | |||
| return {"error": "OAuth data source process failed"}, 400 | |||
| @@ -103,7 +105,7 @@ class OAuthDataSourceSync(Resource): | |||
| try: | |||
| oauth_provider.sync_data_source(binding_id) | |||
| except requests.exceptions.HTTPError as e: | |||
| logging.exception( | |||
| logger.exception( | |||
| "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text | |||
| ) | |||
| return {"error": "OAuth data source process failed"}, 400 | |||
| @@ -55,6 +55,12 @@ class EmailOrPasswordMismatchError(BaseHTTPException): | |||
| code = 400 | |||
| class AuthenticationFailedError(BaseHTTPException): | |||
| error_code = "authentication_failed" | |||
| description = "Invalid email or password." | |||
| code = 401 | |||
| class EmailPasswordLoginLimitError(BaseHTTPException): | |||
| error_code = "email_code_login_limit" | |||
| description = "Too many incorrect password attempts. Please try again later." | |||
| @@ -9,8 +9,8 @@ from configs import dify_config | |||
| from constants.languages import languages | |||
| from controllers.console import api | |||
| from controllers.console.auth.error import ( | |||
| AuthenticationFailedError, | |||
| EmailCodeError, | |||
| EmailOrPasswordMismatchError, | |||
| EmailPasswordLoginLimitError, | |||
| InvalidEmailError, | |||
| InvalidTokenError, | |||
| @@ -79,7 +79,7 @@ class LoginApi(Resource): | |||
| raise AccountBannedError() | |||
| except services.errors.account.AccountPasswordError: | |||
| AccountService.add_login_error_rate_limit(args["email"]) | |||
| raise EmailOrPasswordMismatchError() | |||
| raise AuthenticationFailedError() | |||
| except services.errors.account.AccountNotFoundError: | |||
| if FeatureService.get_system_features().is_allow_register: | |||
| token = AccountService.send_reset_password_email(email=args["email"], language=language) | |||
| @@ -132,6 +132,7 @@ class ResetPasswordSendEmailApi(Resource): | |||
| account = AccountService.get_user_through_email(args["email"]) | |||
| except AccountRegisterError as are: | |||
| raise AccountInFreezeError() | |||
| if account is None: | |||
| if FeatureService.get_system_features().is_allow_register: | |||
| token = AccountService.send_reset_password_email(email=args["email"], language=language) | |||
| @@ -24,6 +24,8 @@ from services.feature_service import FeatureService | |||
| from .. import api | |||
| logger = logging.getLogger(__name__) | |||
| def get_oauth_providers(): | |||
| with current_app.app_context(): | |||
| @@ -80,7 +82,7 @@ class OAuthCallback(Resource): | |||
| user_info = oauth_provider.get_user_info(token) | |||
| except requests.exceptions.RequestException as e: | |||
| error_text = e.response.text if e.response else str(e) | |||
| logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) | |||
| logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) | |||
| return {"error": "OAuth process failed"}, 400 | |||
| if invite_token and RegisterService.is_valid_invite_token(invite_token): | |||
| @@ -564,7 +564,7 @@ class DatasetIndexingStatusApi(Resource): | |||
| } | |||
| documents_status.append(marshal(document_dict, document_status_fields)) | |||
| data = {"data": documents_status} | |||
| return data | |||
| return data, 200 | |||
| class DatasetApiKeyApi(Resource): | |||
| @@ -56,6 +56,8 @@ from models.dataset import DocumentPipelineExecutionLog | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig | |||
| logger = logging.getLogger(__name__) | |||
| class DocumentResource(Resource): | |||
| def get_document(self, dataset_id: str, document_id: str) -> Document: | |||
| @@ -470,25 +472,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 | |||
| data_process_rule = documents[0].dataset_process_rule | |||
| data_process_rule_dict = data_process_rule.to_dict() | |||
| info_list = [] | |||
| extract_settings = [] | |||
| for document in documents: | |||
| if document.indexing_status in {"completed", "error"}: | |||
| raise DocumentAlreadyFinishedError() | |||
| data_source_info = document.data_source_info_dict | |||
| # format document files info | |||
| if data_source_info and "upload_file_id" in data_source_info: | |||
| file_id = data_source_info["upload_file_id"] | |||
| info_list.append(file_id) | |||
| # format document notion info | |||
| elif ( | |||
| data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info | |||
| ): | |||
| pages = [] | |||
| page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]} | |||
| pages.append(page) | |||
| notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} | |||
| info_list.append(notion_info) | |||
| if document.data_source_type == "upload_file": | |||
| file_id = data_source_info["upload_file_id"] | |||
| @@ -969,7 +957,7 @@ class DocumentRetryApi(DocumentResource): | |||
| raise DocumentAlreadyFinishedError() | |||
| retry_documents.append(document) | |||
| except Exception: | |||
| logging.exception("Failed to retry document, document id: %s", document_id) | |||
| logger.exception("Failed to retry document, document id: %s", document_id) | |||
| continue | |||
| # retry document | |||
| DocumentService.retry_document(dataset_id, retry_documents) | |||
| @@ -23,6 +23,8 @@ from fields.hit_testing_fields import hit_testing_record_fields | |||
| from services.dataset_service import DatasetService | |||
| from services.hit_testing_service import HitTestingService | |||
| logger = logging.getLogger(__name__) | |||
| class DatasetsHitTestingBase: | |||
| @staticmethod | |||
| @@ -81,5 +83,5 @@ class DatasetsHitTestingBase: | |||
| except ValueError as e: | |||
| raise ValueError(str(e)) | |||
| except Exception as e: | |||
| logging.exception("Hit testing failed.") | |||
| logger.exception("Hit testing failed.") | |||
| raise InternalServerError(str(e)) | |||
| @@ -26,6 +26,8 @@ from services.errors.audio import ( | |||
| UnsupportedAudioTypeServiceError, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| class ChatAudioApi(InstalledAppResource): | |||
| def post(self, installed_app): | |||
| @@ -38,7 +40,7 @@ class ChatAudioApi(InstalledAppResource): | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -59,7 +61,7 @@ class ChatAudioApi(InstalledAppResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -83,7 +85,7 @@ class ChatTextApi(InstalledAppResource): | |||
| response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -104,5 +106,5 @@ class ChatTextApi(InstalledAppResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -32,6 +32,8 @@ from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| logger = logging.getLogger(__name__) | |||
| # define completion api for user | |||
| class CompletionApi(InstalledAppResource): | |||
| @@ -65,7 +67,7 @@ class CompletionApi(InstalledAppResource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -78,7 +80,7 @@ class CompletionApi(InstalledAppResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -125,7 +127,7 @@ class ChatApi(InstalledAppResource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -140,7 +142,7 @@ class ChatApi(InstalledAppResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -35,6 +35,8 @@ from services.errors.message import ( | |||
| ) | |||
| from services.message_service import MessageService | |||
| logger = logging.getLogger(__name__) | |||
| class MessageListApi(InstalledAppResource): | |||
| @marshal_with(message_infinite_scroll_pagination_fields) | |||
| @@ -126,7 +128,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -158,7 +160,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {"data": questions} | |||
| @@ -43,7 +43,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): | |||
| parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("files", type=list, required=False, location="json") | |||
| args = parser.parse_args() | |||
| assert current_user is not None | |||
| try: | |||
| response = AppGenerateService.generate( | |||
| app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True | |||
| @@ -63,7 +63,7 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -76,6 +76,7 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): | |||
| app_mode = AppMode.value_of(app_model.mode) | |||
| if app_mode != AppMode.WORKFLOW: | |||
| raise NotWorkflowAppError() | |||
| assert current_user is not None | |||
| AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| @@ -9,6 +9,8 @@ from configs import dify_config | |||
| from . import api | |||
| logger = logging.getLogger(__name__) | |||
| class VersionApi(Resource): | |||
| def get(self): | |||
| @@ -34,7 +36,7 @@ class VersionApi(Resource): | |||
| try: | |||
| response = requests.get(check_update_url, {"current_version": args.get("current_version")}, timeout=(3, 10)) | |||
| except Exception as error: | |||
| logging.warning("Check update version error: %s.", str(error)) | |||
| logger.warning("Check update version error: %s.", str(error)) | |||
| result["version"] = args.get("current_version") | |||
| return result | |||
| @@ -55,7 +57,7 @@ def _has_new_version(*, latest_version: str, current_version: str) -> bool: | |||
| # Compare versions | |||
| return latest > current | |||
| except version.InvalidVersion: | |||
| logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) | |||
| logger.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version) | |||
| return False | |||
| @@ -6,7 +6,7 @@ from controllers.console.wraps import account_initialization_required, setup_req | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from libs.login import current_user, login_required | |||
| from models.account import TenantAccountRole | |||
| from models.account import Account, TenantAccountRole | |||
| from services.model_load_balancing_service import ModelLoadBalancingService | |||
| @@ -15,10 +15,12 @@ class LoadBalancingCredentialsValidateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| assert isinstance(current_user, Account) | |||
| if not TenantAccountRole.is_privileged_role(current_user.current_role): | |||
| raise Forbidden() | |||
| tenant_id = current_user.current_tenant_id | |||
| assert tenant_id is not None | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| @@ -64,10 +66,12 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str, config_id: str): | |||
| assert isinstance(current_user, Account) | |||
| if not TenantAccountRole.is_privileged_role(current_user.current_role): | |||
| raise Forbidden() | |||
| tenant_id = current_user.current_tenant_id | |||
| assert tenant_id is not None | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| @@ -54,7 +54,7 @@ class MemberInviteEmailApi(Resource): | |||
| @cloud_edition_billing_resource_check("members") | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("emails", type=str, required=True, location="json", action="append") | |||
| parser.add_argument("emails", type=list, required=True, location="json") | |||
| parser.add_argument("role", type=str, required=True, default="admin", location="json") | |||
| parser.add_argument("language", type=str, required=False, location="json") | |||
| args = parser.parse_args() | |||
| @@ -10,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, setup_req | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.helper import StrLen, uuid_value | |||
| from libs.login import login_required | |||
| from services.billing_service import BillingService | |||
| from services.model_provider_service import ModelProviderService | |||
| @@ -45,67 +46,71 @@ class ModelProviderCredentialApi(Resource): | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| # if credential_id is not provided, return current used credential | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) | |||
| credentials = model_provider_service.get_provider_credential( | |||
| tenant_id=tenant_id, provider=provider, credential_id=args.get("credential_id") | |||
| ) | |||
| return {"credentials": credentials} | |||
| class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| result = True | |||
| error = "" | |||
| try: | |||
| model_provider_service.provider_credentials_validate( | |||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||
| model_provider_service.create_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credentials=args["credentials"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response["error"] = error or "Unknown error" | |||
| return response | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| class ModelProviderApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| def put(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_provider_credentials( | |||
| tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] | |||
| model_provider_service.update_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credentials=args["credentials"], | |||
| credential_id=args["credential_id"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| return {"result": "success"} | |||
| @setup_required | |||
| @login_required | |||
| @@ -113,13 +118,70 @@ class ModelProviderApi(Resource): | |||
| def delete(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) | |||
| model_provider_service.remove_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, provider=provider, credential_id=args["credential_id"] | |||
| ) | |||
| return {"result": "success"}, 204 | |||
| class ModelProviderCredentialSwitchApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| service = ModelProviderService() | |||
| service.switch_active_provider_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"} | |||
| class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| result = True | |||
| error = "" | |||
| try: | |||
| model_provider_service.validate_provider_credentials( | |||
| tenant_id=tenant_id, provider=provider, credentials=args["credentials"] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {"result": "success" if result else "error"} | |||
| if not result: | |||
| response["error"] = error or "Unknown error" | |||
| return response | |||
| class ModelProviderIconApi(Resource): | |||
| """ | |||
| Get model provider icon | |||
| @@ -187,8 +249,10 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") | |||
| api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers/<path:provider>/credentials") | |||
| api.add_resource( | |||
| ModelProviderCredentialSwitchApi, "/workspaces/current/model-providers/<path:provider>/credentials/switch" | |||
| ) | |||
| api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers/<path:provider>/credentials/validate") | |||
| api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/<path:provider>") | |||
| api.add_resource( | |||
| PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers/<path:provider>/preferred-provider-type" | |||
| @@ -9,10 +9,13 @@ from controllers.console.wraps import account_initialization_required, setup_req | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.helper import StrLen, uuid_value | |||
| from libs.login import login_required | |||
| from services.model_load_balancing_service import ModelLoadBalancingService | |||
| from services.model_provider_service import ModelProviderService | |||
| logger = logging.getLogger(__name__) | |||
| class DefaultModelApi(Resource): | |||
| @setup_required | |||
| @@ -72,7 +75,7 @@ class DefaultModelApi(Resource): | |||
| model=model_setting["model"], | |||
| ) | |||
| except Exception as ex: | |||
| logging.exception( | |||
| logger.exception( | |||
| "Failed to update default model, model type: %s, model: %s", | |||
| model_setting["model_type"], | |||
| model_setting.get("model"), | |||
| @@ -98,6 +101,7 @@ class ModelProviderModelApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| # To save the model's load balance configs | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| @@ -113,22 +117,26 @@ class ModelProviderModelApi(Resource): | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") | |||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="json") | |||
| args = parser.parse_args() | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| if args.get("config_from", "") == "custom-model": | |||
| if not args.get("credential_id"): | |||
| raise ValueError("credential_id is required when configuring a custom-model") | |||
| service = ModelProviderService() | |||
| service.switch_active_custom_model_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| if ( | |||
| "load_balancing" in args | |||
| and args["load_balancing"] | |||
| and "enabled" in args["load_balancing"] | |||
| and args["load_balancing"]["enabled"] | |||
| ): | |||
| if "configs" not in args["load_balancing"]: | |||
| raise ValueError("invalid load balancing configs") | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| if "load_balancing" in args and args["load_balancing"] and "configs" in args["load_balancing"]: | |||
| # save load balancing configs | |||
| model_load_balancing_service.update_load_balancing_configs( | |||
| tenant_id=tenant_id, | |||
| @@ -136,37 +144,17 @@ class ModelProviderModelApi(Resource): | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| configs=args["load_balancing"]["configs"], | |||
| config_from=args.get("config_from", ""), | |||
| ) | |||
| # enable load balancing | |||
| model_load_balancing_service.enable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| else: | |||
| # disable load balancing | |||
| model_load_balancing_service.disable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| if args.get("config_from", "") != "predefined-model": | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| logging.exception( | |||
| "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", | |||
| tenant_id, | |||
| args.get("model"), | |||
| args.get("model_type"), | |||
| ) | |||
| raise ValueError(str(ex)) | |||
| if args.get("load_balancing", {}).get("enabled"): | |||
| model_load_balancing_service.enable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| else: | |||
| model_load_balancing_service.disable_model_load_balancing( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return {"result": "success"}, 200 | |||
| @@ -192,7 +180,7 @@ class ModelProviderModelApi(Resource): | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_model_credentials( | |||
| model_provider_service.remove_model( | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| @@ -216,11 +204,17 @@ class ModelProviderModelCredentialApi(Resource): | |||
| choices=[mt.value for mt in ModelType], | |||
| location="args", | |||
| ) | |||
| parser.add_argument("config_from", type=str, required=False, nullable=True, location="args") | |||
| parser.add_argument("credential_id", type=uuid_value, required=False, nullable=True, location="args") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_model_credentials( | |||
| tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] | |||
| current_credential = model_provider_service.get_model_credential( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args.get("credential_id"), | |||
| ) | |||
| model_load_balancing_service = ModelLoadBalancingService() | |||
| @@ -228,10 +222,173 @@ class ModelProviderModelCredentialApi(Resource): | |||
| tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | |||
| ) | |||
| return { | |||
| "credentials": credentials, | |||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||
| } | |||
| if args.get("config_from", "") == "predefined-model": | |||
| available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( | |||
| tenant_id=tenant_id, provider_name=provider | |||
| ) | |||
| else: | |||
| model_type = ModelType.value_of(args["model_type"]).to_origin_model_type() | |||
| available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( | |||
| tenant_id=tenant_id, provider_name=provider, model_type=model_type, model_name=args["model"] | |||
| ) | |||
| return jsonable_encoder( | |||
| { | |||
| "credentials": current_credential.get("credentials") if current_credential else {}, | |||
| "current_credential_id": current_credential.get("current_credential_id") | |||
| if current_credential | |||
| else None, | |||
| "current_credential_name": current_credential.get("current_credential_name") | |||
| if current_credential | |||
| else None, | |||
| "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | |||
| "available_credentials": available_credentials, | |||
| } | |||
| ) | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.create_model_credential( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| model_type=args["model_type"], | |||
| credentials=args["credentials"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| logger.exception( | |||
| "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s", | |||
| tenant_id, | |||
| args.get("model"), | |||
| args.get("model_type"), | |||
| ) | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"}, 201 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def put(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | |||
| parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.update_model_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credentials=args["credentials"], | |||
| credential_id=args["credential_id"], | |||
| credential_name=args["name"], | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {"result": "success"} | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credential_id", type=uuid_value, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_model_credential( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"}, 204 | |||
| class ModelProviderModelCredentialSwitchApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("model", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "model_type", | |||
| type=str, | |||
| required=True, | |||
| nullable=False, | |||
| choices=[mt.value for mt in ModelType], | |||
| location="json", | |||
| ) | |||
| parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| service = ModelProviderService() | |||
| service.add_model_credential_to_model_list( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=provider, | |||
| model_type=args["model_type"], | |||
| model=args["model"], | |||
| credential_id=args["credential_id"], | |||
| ) | |||
| return {"result": "success"} | |||
| class ModelProviderModelEnableApi(Resource): | |||
| @@ -314,7 +471,7 @@ class ModelProviderModelValidateApi(Resource): | |||
| error = "" | |||
| try: | |||
| model_provider_service.model_credentials_validate( | |||
| model_provider_service.validate_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args["model"], | |||
| @@ -379,6 +536,10 @@ api.add_resource( | |||
| api.add_resource( | |||
| ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<path:provider>/models/credentials" | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelCredentialSwitchApi, | |||
| "/workspaces/current/model-providers/<path:provider>/models/credentials/switch", | |||
| ) | |||
| api.add_resource( | |||
| ModelProviderModelValidateApi, "/workspaces/current/model-providers/<path:provider>/models/credentials/validate" | |||
| ) | |||
| @@ -31,6 +31,9 @@ from services.feature_service import FeatureService | |||
| from services.file_service import FileService | |||
| from services.workspace_service import WorkspaceService | |||
| logger = logging.getLogger(__name__) | |||
| provider_fields = { | |||
| "provider_name": fields.String, | |||
| "provider_type": fields.String, | |||
| @@ -120,7 +123,7 @@ class TenantApi(Resource): | |||
| @marshal_with(tenant_fields) | |||
| def get(self): | |||
| if request.path == "/info": | |||
| logging.warning("Deprecated URL /info was used.") | |||
| logger.warning("Deprecated URL /info was used.") | |||
| tenant = current_user.current_tenant | |||
| @@ -1,10 +1,23 @@ | |||
| from flask import Blueprint | |||
| from flask_restx import Namespace | |||
| from libs.external_api import ExternalApi | |||
| bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") | |||
| api = ExternalApi(bp) | |||
| api = ExternalApi( | |||
| bp, | |||
| version="1.0", | |||
| title="Inner API", | |||
| description="Internal APIs for enterprise features, billing, and plugin communication", | |||
| doc="/docs", # Enable Swagger UI at /inner/api/docs | |||
| ) | |||
| # Create namespace | |||
| inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") | |||
| from . import mail | |||
| from .plugin import plugin | |||
| from .workspace import workspace | |||
| api.add_namespace(inner_api_ns) | |||
| @@ -1,7 +1,7 @@ | |||
| from flask_restx import Resource, reqparse | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.inner_api import api | |||
| from controllers.inner_api import inner_api_ns | |||
| from controllers.inner_api.wraps import billing_inner_api_only, enterprise_inner_api_only | |||
| from tasks.mail_inner_task import send_inner_email_task | |||
| @@ -26,13 +26,45 @@ class BaseMail(Resource): | |||
| return {"message": "success"}, 200 | |||
| @inner_api_ns.route("/enterprise/mail") | |||
| class EnterpriseMail(BaseMail): | |||
| method_decorators = [setup_required, enterprise_inner_api_only] | |||
| @inner_api_ns.doc("send_enterprise_mail") | |||
| @inner_api_ns.doc(description="Send internal email for enterprise features") | |||
| @inner_api_ns.expect(_mail_parser) | |||
| @inner_api_ns.doc( | |||
| responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} | |||
| ) | |||
| def post(self): | |||
| """Send internal email for enterprise features. | |||
| This endpoint allows sending internal emails for enterprise-specific | |||
| notifications and communications. | |||
| Returns: | |||
| dict: Success message with status code 200 | |||
| """ | |||
| return super().post() | |||
| @inner_api_ns.route("/billing/mail") | |||
| class BillingMail(BaseMail): | |||
| method_decorators = [setup_required, billing_inner_api_only] | |||
| @inner_api_ns.doc("send_billing_mail") | |||
| @inner_api_ns.doc(description="Send internal email for billing notifications") | |||
| @inner_api_ns.expect(_mail_parser) | |||
| @inner_api_ns.doc( | |||
| responses={200: "Email sent successfully", 401: "Unauthorized - invalid API key", 404: "Service not available"} | |||
| ) | |||
| def post(self): | |||
| """Send internal email for billing notifications. | |||
| This endpoint allows sending internal emails for billing-related | |||
| notifications and alerts. | |||
| api.add_resource(EnterpriseMail, "/enterprise/mail") | |||
| api.add_resource(BillingMail, "/billing/mail") | |||
| Returns: | |||
| dict: Success message with status code 200 | |||
| """ | |||
| return super().post() | |||
| @@ -1,7 +1,7 @@ | |||
| from flask_restx import Resource | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.inner_api import api | |||
| from controllers.inner_api import inner_api_ns | |||
| from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data | |||
| from controllers.inner_api.wraps import plugin_inner_api_only | |||
| from core.file.helpers import get_signed_file_url_for_plugin | |||
| @@ -35,11 +35,21 @@ from models.account import Account, Tenant | |||
| from models.model import EndUser | |||
| @inner_api_ns.route("/invoke/llm") | |||
| class PluginInvokeLLMApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeLLM) | |||
| @inner_api_ns.doc("plugin_invoke_llm") | |||
| @inner_api_ns.doc(description="Invoke LLM models through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "LLM invocation successful (streaming response)", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM): | |||
| def generator(): | |||
| response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) | |||
| @@ -48,11 +58,21 @@ class PluginInvokeLLMApi(Resource): | |||
| return length_prefixed_response(0xF, generator()) | |||
| @inner_api_ns.route("/invoke/llm/structured-output") | |||
| class PluginInvokeLLMWithStructuredOutputApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput) | |||
| @inner_api_ns.doc("plugin_invoke_llm_structured") | |||
| @inner_api_ns.doc(description="Invoke LLM models with structured output through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "LLM structured output invocation successful (streaming response)", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput): | |||
| def generator(): | |||
| response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output( | |||
| @@ -63,11 +83,21 @@ class PluginInvokeLLMWithStructuredOutputApi(Resource): | |||
| return length_prefixed_response(0xF, generator()) | |||
| @inner_api_ns.route("/invoke/text-embedding") | |||
| class PluginInvokeTextEmbeddingApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTextEmbedding) | |||
| @inner_api_ns.doc("plugin_invoke_text_embedding") | |||
| @inner_api_ns.doc(description="Invoke text embedding models through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Text embedding successful", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): | |||
| try: | |||
| return jsonable_encoder( | |||
| @@ -83,11 +113,17 @@ class PluginInvokeTextEmbeddingApi(Resource): | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| @inner_api_ns.route("/invoke/rerank") | |||
| class PluginInvokeRerankApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeRerank) | |||
| @inner_api_ns.doc("plugin_invoke_rerank") | |||
| @inner_api_ns.doc(description="Invoke rerank models through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={200: "Rerank successful", 401: "Unauthorized - invalid API key", 404: "Service not available"} | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank): | |||
| try: | |||
| return jsonable_encoder( | |||
| @@ -103,11 +139,21 @@ class PluginInvokeRerankApi(Resource): | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| @inner_api_ns.route("/invoke/tts") | |||
| class PluginInvokeTTSApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTTS) | |||
| @inner_api_ns.doc("plugin_invoke_tts") | |||
| @inner_api_ns.doc(description="Invoke text-to-speech models through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "TTS invocation successful (streaming response)", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS): | |||
| def generator(): | |||
| response = PluginModelBackwardsInvocation.invoke_tts( | |||
| @@ -120,11 +166,17 @@ class PluginInvokeTTSApi(Resource): | |||
| return length_prefixed_response(0xF, generator()) | |||
| @inner_api_ns.route("/invoke/speech2text") | |||
| class PluginInvokeSpeech2TextApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeSpeech2Text) | |||
| @inner_api_ns.doc("plugin_invoke_speech2text") | |||
| @inner_api_ns.doc(description="Invoke speech-to-text models through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={200: "Speech2Text successful", 401: "Unauthorized - invalid API key", 404: "Service not available"} | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): | |||
| try: | |||
| return jsonable_encoder( | |||
| @@ -140,11 +192,17 @@ class PluginInvokeSpeech2TextApi(Resource): | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| @inner_api_ns.route("/invoke/moderation") | |||
| class PluginInvokeModerationApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeModeration) | |||
| @inner_api_ns.doc("plugin_invoke_moderation") | |||
| @inner_api_ns.doc(description="Invoke moderation models through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={200: "Moderation successful", 401: "Unauthorized - invalid API key", 404: "Service not available"} | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration): | |||
| try: | |||
| return jsonable_encoder( | |||
| @@ -160,11 +218,21 @@ class PluginInvokeModerationApi(Resource): | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| @inner_api_ns.route("/invoke/tool") | |||
| class PluginInvokeToolApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeTool) | |||
| @inner_api_ns.doc("plugin_invoke_tool") | |||
| @inner_api_ns.doc(description="Invoke tools through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Tool invocation successful (streaming response)", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool): | |||
| def generator(): | |||
| return PluginToolBackwardsInvocation.convert_to_event_stream( | |||
| @@ -182,11 +250,21 @@ class PluginInvokeToolApi(Resource): | |||
| return length_prefixed_response(0xF, generator()) | |||
| @inner_api_ns.route("/invoke/parameter-extractor") | |||
| class PluginInvokeParameterExtractorNodeApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeParameterExtractorNode) | |||
| @inner_api_ns.doc("plugin_invoke_parameter_extractor") | |||
| @inner_api_ns.doc(description="Invoke parameter extractor node through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Parameter extraction successful", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): | |||
| try: | |||
| return jsonable_encoder( | |||
| @@ -205,11 +283,21 @@ class PluginInvokeParameterExtractorNodeApi(Resource): | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| @inner_api_ns.route("/invoke/question-classifier") | |||
| class PluginInvokeQuestionClassifierNodeApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) | |||
| @inner_api_ns.doc("plugin_invoke_question_classifier") | |||
| @inner_api_ns.doc(description="Invoke question classifier node through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Question classification successful", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): | |||
| try: | |||
| return jsonable_encoder( | |||
| @@ -228,11 +316,21 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): | |||
| return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e))) | |||
| @inner_api_ns.route("/invoke/app") | |||
| class PluginInvokeAppApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeApp) | |||
| @inner_api_ns.doc("plugin_invoke_app") | |||
| @inner_api_ns.doc(description="Invoke application through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "App invocation successful (streaming response)", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp): | |||
| response = PluginAppBackwardsInvocation.invoke_app( | |||
| app_id=payload.app_id, | |||
| @@ -248,11 +346,21 @@ class PluginInvokeAppApi(Resource): | |||
| return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response)) | |||
| @inner_api_ns.route("/invoke/encrypt") | |||
| class PluginInvokeEncryptApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeEncrypt) | |||
| @inner_api_ns.doc("plugin_invoke_encrypt") | |||
| @inner_api_ns.doc(description="Encrypt or decrypt data through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Encryption/decryption successful", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt): | |||
| """ | |||
| encrypt or decrypt data | |||
| @@ -265,11 +373,21 @@ class PluginInvokeEncryptApi(Resource): | |||
| return BaseBackwardsInvocationResponse(error=str(e)).model_dump() | |||
| @inner_api_ns.route("/invoke/summary") | |||
| class PluginInvokeSummaryApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestInvokeSummary) | |||
| @inner_api_ns.doc("plugin_invoke_summary") | |||
| @inner_api_ns.doc(description="Invoke summary functionality through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Summary generation successful", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary): | |||
| try: | |||
| return BaseBackwardsInvocationResponse( | |||
| @@ -285,40 +403,43 @@ class PluginInvokeSummaryApi(Resource): | |||
| return BaseBackwardsInvocationResponse(error=str(e)).model_dump() | |||
| @inner_api_ns.route("/upload/file/request") | |||
| class PluginUploadFileRequestApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestRequestUploadFile) | |||
| @inner_api_ns.doc("plugin_upload_file_request") | |||
| @inner_api_ns.doc(description="Request signed URL for file upload through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Signed URL generated successfully", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): | |||
| # generate signed url | |||
| url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id) | |||
| return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() | |||
| @inner_api_ns.route("/fetch/app/info") | |||
| class PluginFetchAppInfoApi(Resource): | |||
| @setup_required | |||
| @plugin_inner_api_only | |||
| @get_user_tenant | |||
| @plugin_data(payload_type=RequestFetchAppInfo) | |||
| @inner_api_ns.doc("plugin_fetch_app_info") | |||
| @inner_api_ns.doc(description="Fetch application information through plugin interface") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "App information retrieved successfully", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo): | |||
| return BaseBackwardsInvocationResponse( | |||
| data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id) | |||
| ).model_dump() | |||
| api.add_resource(PluginInvokeLLMApi, "/invoke/llm") | |||
| api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output") | |||
| api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") | |||
| api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") | |||
| api.add_resource(PluginInvokeTTSApi, "/invoke/tts") | |||
| api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text") | |||
| api.add_resource(PluginInvokeModerationApi, "/invoke/moderation") | |||
| api.add_resource(PluginInvokeToolApi, "/invoke/tool") | |||
| api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor") | |||
| api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier") | |||
| api.add_resource(PluginInvokeAppApi, "/invoke/app") | |||
| api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") | |||
| api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") | |||
| api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") | |||
| api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info") | |||
| @@ -3,7 +3,7 @@ import json | |||
| from flask_restx import Resource, reqparse | |||
| from controllers.console.wraps import setup_required | |||
| from controllers.inner_api import api | |||
| from controllers.inner_api import inner_api_ns | |||
| from controllers.inner_api.wraps import enterprise_inner_api_only | |||
| from events.tenant_event import tenant_was_created | |||
| from extensions.ext_database import db | |||
| @@ -11,9 +11,19 @@ from models.account import Account | |||
| from services.account_service import TenantService | |||
| @inner_api_ns.route("/enterprise/workspace") | |||
| class EnterpriseWorkspace(Resource): | |||
| @setup_required | |||
| @enterprise_inner_api_only | |||
| @inner_api_ns.doc("create_enterprise_workspace") | |||
| @inner_api_ns.doc(description="Create a new enterprise workspace with owner assignment") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Workspace created successfully", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Owner account not found or service not available", | |||
| } | |||
| ) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| @@ -44,9 +54,19 @@ class EnterpriseWorkspace(Resource): | |||
| } | |||
| @inner_api_ns.route("/enterprise/workspace/ownerless") | |||
| class EnterpriseWorkspaceNoOwnerEmail(Resource): | |||
| @setup_required | |||
| @enterprise_inner_api_only | |||
| @inner_api_ns.doc("create_enterprise_workspace_ownerless") | |||
| @inner_api_ns.doc(description="Create a new enterprise workspace without initial owner assignment") | |||
| @inner_api_ns.doc( | |||
| responses={ | |||
| 200: "Workspace created successfully", | |||
| 401: "Unauthorized - invalid API key", | |||
| 404: "Service not available", | |||
| } | |||
| ) | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("name", type=str, required=True, location="json") | |||
| @@ -71,7 +91,3 @@ class EnterpriseWorkspaceNoOwnerEmail(Resource): | |||
| "message": "enterprise workspace created.", | |||
| "tenant": resp, | |||
| } | |||
| api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") | |||
| api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless") | |||
| @@ -10,6 +10,7 @@ from controllers.service_api.wraps import validate_app_token | |||
| from extensions.ext_redis import redis_client | |||
| from fields.annotation_fields import annotation_fields, build_annotation_model | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.model import App | |||
| from services.annotation_service import AppAnnotationService | |||
| @@ -163,6 +164,7 @@ class AnnotationUpdateDeleteApi(Resource): | |||
| @service_api_ns.marshal_with(build_annotation_model(service_api_ns)) | |||
| def put(self, app_model: App, annotation_id): | |||
| """Update an existing annotation.""" | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -185,6 +187,8 @@ class AnnotationUpdateDeleteApi(Resource): | |||
| @validate_app_token | |||
| def delete(self, app_model: App, annotation_id): | |||
| """Delete an annotation.""" | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| @@ -29,6 +29,8 @@ from services.errors.audio import ( | |||
| UnsupportedAudioTypeServiceError, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| @service_api_ns.route("/audio-to-text") | |||
| class AudioApi(Resource): | |||
| @@ -57,7 +59,7 @@ class AudioApi(Resource): | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -78,7 +80,7 @@ class AudioApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -121,7 +123,7 @@ class TextApi(Resource): | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -142,5 +144,5 @@ class TextApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -33,6 +33,9 @@ from services.app_generate_service import AppGenerateService | |||
| from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError | |||
| from services.errors.llm import InvokeRateLimitError | |||
| logger = logging.getLogger(__name__) | |||
| # Define parser for completion API | |||
| completion_parser = reqparse.RequestParser() | |||
| completion_parser.add_argument( | |||
| @@ -118,7 +121,7 @@ class CompletionApi(Resource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -131,7 +134,7 @@ class CompletionApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -209,7 +212,7 @@ class ChatApi(Resource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -224,7 +227,7 @@ class ChatApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -22,6 +22,9 @@ from services.errors.message import ( | |||
| ) | |||
| from services.message_service import MessageService | |||
| logger = logging.getLogger(__name__) | |||
| # Define parsers for message APIs | |||
| message_list_parser = reqparse.RequestParser() | |||
| message_list_parser.add_argument( | |||
| @@ -216,7 +219,7 @@ class MessageSuggestedApi(Resource): | |||
| except SuggestedQuestionsAfterAnswerDisabledError: | |||
| raise BadRequest("Suggested Questions Is Disabled.") | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {"result": "success", "data": questions} | |||
| @@ -174,7 +174,7 @@ class WorkflowRunApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -239,7 +239,7 @@ class WorkflowRunByIdApi(Resource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -18,6 +18,7 @@ from core.provider_manager import ProviderManager | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from fields.tag_fields import build_dataset_tag_fields | |||
| from libs.login import current_user | |||
| from models.account import Account | |||
| from models.dataset import Dataset, DatasetPermissionEnum | |||
| from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import RetrievalModel | |||
| @@ -213,7 +214,10 @@ class DatasetListApi(DatasetApiResource): | |||
| ) | |||
| # check embedding setting | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) | |||
| assert isinstance(current_user, Account) | |||
| cid = current_user.current_tenant_id | |||
| assert cid is not None | |||
| configurations = provider_manager.get_configurations(tenant_id=cid) | |||
| embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) | |||
| @@ -266,6 +270,7 @@ class DatasetListApi(DatasetApiResource): | |||
| ) | |||
| try: | |||
| assert isinstance(current_user, Account) | |||
| dataset = DatasetService.create_empty_dataset( | |||
| tenant_id=tenant_id, | |||
| name=args["name"], | |||
| @@ -319,7 +324,10 @@ class DatasetApi(DatasetApiResource): | |||
| # check embedding setting | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) | |||
| assert isinstance(current_user, Account) | |||
| cid = current_user.current_tenant_id | |||
| assert cid is not None | |||
| configurations = provider_manager.get_configurations(tenant_id=cid) | |||
| embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) | |||
| @@ -391,6 +399,7 @@ class DatasetApi(DatasetApiResource): | |||
| raise NotFound("Dataset not found.") | |||
| result_data = marshal(dataset, dataset_detail_fields) | |||
| assert isinstance(current_user, Account) | |||
| tenant_id = current_user.current_tenant_id | |||
| if data.get("partial_member_list") and data.get("permission") == "partial_members": | |||
| @@ -532,7 +541,10 @@ class DatasetTagsApi(DatasetApiResource): | |||
| @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) | |||
| def get(self, _, dataset_id): | |||
| """Get all knowledge type tags.""" | |||
| tags = TagService.get_tags("knowledge", current_user.current_tenant_id) | |||
| assert isinstance(current_user, Account) | |||
| cid = current_user.current_tenant_id | |||
| assert cid is not None | |||
| tags = TagService.get_tags("knowledge", cid) | |||
| return tags, 200 | |||
| @@ -550,6 +562,7 @@ class DatasetTagsApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| """Add a knowledge type tag.""" | |||
| assert isinstance(current_user, Account) | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| @@ -573,6 +586,7 @@ class DatasetTagsApi(DatasetApiResource): | |||
| @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns)) | |||
| @validate_dataset_token | |||
| def patch(self, _, dataset_id): | |||
| assert isinstance(current_user, Account) | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| @@ -599,6 +613,7 @@ class DatasetTagsApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def delete(self, _, dataset_id): | |||
| """Delete a knowledge type tag.""" | |||
| assert isinstance(current_user, Account) | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| args = tag_delete_parser.parse_args() | |||
| @@ -622,6 +637,7 @@ class DatasetTagBindingApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator | |||
| assert isinstance(current_user, Account) | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| @@ -647,6 +663,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): | |||
| @validate_dataset_token | |||
| def post(self, _, dataset_id): | |||
| # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator | |||
| assert isinstance(current_user, Account) | |||
| if not (current_user.is_editor or current_user.is_dataset_editor): | |||
| raise Forbidden() | |||
| @@ -672,6 +689,8 @@ class DatasetTagsBindingStatusApi(DatasetApiResource): | |||
| def get(self, _, *args, **kwargs): | |||
| """Get all knowledge type tags.""" | |||
| dataset_id = kwargs.get("dataset_id") | |||
| assert isinstance(current_user, Account) | |||
| assert current_user.current_tenant_id is not None | |||
| tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) | |||
| tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] | |||
| response = {"data": tags_list, "total": len(tags)} | |||
| @@ -16,6 +16,8 @@ from services.enterprise.enterprise_service import EnterpriseService | |||
| from services.feature_service import FeatureService | |||
| from services.webapp_auth_service import WebAppAuthService | |||
| logger = logging.getLogger(__name__) | |||
| class AppParameterApi(WebApiResource): | |||
| """Resource for app variables.""" | |||
| @@ -92,7 +94,7 @@ class AppWebAuthPermission(Resource): | |||
| except Unauthorized: | |||
| raise | |||
| except Exception: | |||
| logging.exception("Unexpected error during auth verification") | |||
| logger.exception("Unexpected error during auth verification") | |||
| raise | |||
| features = FeatureService.get_system_features() | |||
| @@ -28,6 +28,8 @@ from services.errors.audio import ( | |||
| UnsupportedAudioTypeServiceError, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| class AudioApi(WebApiResource): | |||
| def post(self, app_model: App, end_user): | |||
| @@ -38,7 +40,7 @@ class AudioApi(WebApiResource): | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -59,7 +61,7 @@ class AudioApi(WebApiResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("Failed to handle post request to AudioApi") | |||
| logger.exception("Failed to handle post request to AudioApi") | |||
| raise InternalServerError() | |||
| @@ -84,7 +86,7 @@ class TextApi(WebApiResource): | |||
| return response | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except NoAudioUploadedServiceError: | |||
| raise NoAudioUploadedError() | |||
| @@ -105,7 +107,7 @@ class TextApi(WebApiResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("Failed to handle post request to TextApi") | |||
| logger.exception("Failed to handle post request to TextApi") | |||
| raise InternalServerError() | |||
| @@ -31,6 +31,8 @@ from models.model import AppMode | |||
| from services.app_generate_service import AppGenerateService | |||
| from services.errors.llm import InvokeRateLimitError | |||
| logger = logging.getLogger(__name__) | |||
| # define completion api for user | |||
| class CompletionApi(WebApiResource): | |||
| @@ -61,7 +63,7 @@ class CompletionApi(WebApiResource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -74,7 +76,7 @@ class CompletionApi(WebApiResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -119,7 +121,7 @@ class ChatApi(WebApiResource): | |||
| except services.errors.conversation.ConversationCompletedError: | |||
| raise ConversationCompletedError() | |||
| except services.errors.app_model_config.AppModelConfigBrokenError: | |||
| logging.exception("App model config broken.") | |||
| logger.exception("App model config broken.") | |||
| raise AppUnavailableError() | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -134,7 +136,7 @@ class ChatApi(WebApiResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception as e: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -7,13 +7,14 @@ from sqlalchemy import select | |||
| from sqlalchemy.orm import Session | |||
| from controllers.console.auth.error import ( | |||
| AuthenticationFailedError, | |||
| EmailCodeError, | |||
| EmailPasswordResetLimitError, | |||
| InvalidEmailError, | |||
| InvalidTokenError, | |||
| PasswordMismatchError, | |||
| ) | |||
| from controllers.console.error import AccountNotFound, EmailSendIpLimitError | |||
| from controllers.console.error import EmailSendIpLimitError | |||
| from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required | |||
| from controllers.web import api | |||
| from extensions.ext_database import db | |||
| @@ -46,7 +47,7 @@ class ForgotPasswordSendEmailApi(Resource): | |||
| account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() | |||
| token = None | |||
| if account is None: | |||
| raise AccountNotFound() | |||
| raise AuthenticationFailedError() | |||
| else: | |||
| token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) | |||
| @@ -131,7 +132,7 @@ class ForgotPasswordResetApi(Resource): | |||
| if account: | |||
| self._update_existing_account(account, password_hashed, salt, session) | |||
| else: | |||
| raise AccountNotFound() | |||
| raise AuthenticationFailedError() | |||
| return {"result": "success"} | |||
| @@ -2,8 +2,12 @@ from flask_restx import Resource, reqparse | |||
| from jwt import InvalidTokenError # type: ignore | |||
| import services | |||
| from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError | |||
| from controllers.console.error import AccountBannedError, AccountNotFound | |||
| from controllers.console.auth.error import ( | |||
| AuthenticationFailedError, | |||
| EmailCodeError, | |||
| InvalidEmailError, | |||
| ) | |||
| from controllers.console.error import AccountBannedError | |||
| from controllers.console.wraps import only_edition_enterprise, setup_required | |||
| from controllers.web import api | |||
| from libs.helper import email | |||
| @@ -29,9 +33,9 @@ class LoginApi(Resource): | |||
| except services.errors.account.AccountLoginError: | |||
| raise AccountBannedError() | |||
| except services.errors.account.AccountPasswordError: | |||
| raise EmailOrPasswordMismatchError() | |||
| raise AuthenticationFailedError() | |||
| except services.errors.account.AccountNotFoundError: | |||
| raise AccountNotFound() | |||
| raise AuthenticationFailedError() | |||
| token = WebAppAuthService.login(account=account) | |||
| return {"result": "success", "data": {"access_token": token}} | |||
| @@ -63,7 +67,7 @@ class EmailCodeLoginSendEmailApi(Resource): | |||
| account = WebAppAuthService.get_user_through_email(args["email"]) | |||
| if account is None: | |||
| raise AccountNotFound() | |||
| raise AuthenticationFailedError() | |||
| else: | |||
| token = WebAppAuthService.send_email_code_login_email(account=account, language=language) | |||
| @@ -95,7 +99,7 @@ class EmailCodeLoginApi(Resource): | |||
| WebAppAuthService.revoke_email_code_login_token(args["token"]) | |||
| account = WebAppAuthService.get_user_through_email(user_email) | |||
| if not account: | |||
| raise AccountNotFound() | |||
| raise AuthenticationFailedError() | |||
| token = WebAppAuthService.login(account=account) | |||
| AccountService.reset_login_error_rate_limit(args["email"]) | |||
| @@ -35,6 +35,8 @@ from services.errors.message import ( | |||
| ) | |||
| from services.message_service import MessageService | |||
| logger = logging.getLogger(__name__) | |||
| class MessageListApi(WebApiResource): | |||
| message_fields = { | |||
| @@ -145,7 +147,7 @@ class MessageMoreLikeThisApi(WebApiResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -176,7 +178,7 @@ class MessageSuggestedQuestionApi(WebApiResource): | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(e.description) | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| return {"data": questions} | |||
| @@ -62,7 +62,7 @@ class WorkflowRunApi(WebApiResource): | |||
| except ValueError as e: | |||
| raise e | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| logger.exception("internal server error.") | |||
| raise InternalServerError() | |||
| @@ -3,6 +3,17 @@ import re | |||
| from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType | |||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||
| _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( | |||
| [ | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.PARAGRAPH, | |||
| VariableEntityType.NUMBER, | |||
| VariableEntityType.EXTERNAL_DATA_TOOL, | |||
| VariableEntityType.CHECKBOX, | |||
| ] | |||
| ) | |||
| class BasicVariablesConfigManager: | |||
| @classmethod | |||
| @@ -47,6 +58,7 @@ class BasicVariablesConfigManager: | |||
| VariableEntityType.PARAGRAPH, | |||
| VariableEntityType.NUMBER, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.CHECKBOX, | |||
| }: | |||
| variable = variables[variable_type] | |||
| variable_entities.append( | |||
| @@ -96,8 +108,17 @@ class BasicVariablesConfigManager: | |||
| variables = [] | |||
| for item in config["user_input_form"]: | |||
| key = list(item.keys())[0] | |||
| if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: | |||
| raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") | |||
| # if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: | |||
| if key not in { | |||
| VariableEntityType.TEXT_INPUT, | |||
| VariableEntityType.SELECT, | |||
| VariableEntityType.PARAGRAPH, | |||
| VariableEntityType.NUMBER, | |||
| VariableEntityType.EXTERNAL_DATA_TOOL, | |||
| VariableEntityType.CHECKBOX, | |||
| }: | |||
| allowed_keys = ", ".join(i.value for i in _ALLOWED_VARIABLE_ENTITY_TYPE) | |||
| raise ValueError(f"Keys in user_input_form list can only be {allowed_keys}") | |||
| form_item = item[key] | |||
| if "label" not in form_item: | |||
| @@ -8,6 +8,8 @@ from core.app.entities.task_entities import AppBlockingResponse, AppStreamRespon | |||
| from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| logger = logging.getLogger(__name__) | |||
| class AppGenerateResponseConverter(ABC): | |||
| _blocking_response_type: type[AppBlockingResponse] | |||
| @@ -120,7 +122,7 @@ class AppGenerateResponseConverter(ABC): | |||
| if data: | |||
| data.setdefault("message", getattr(e, "description", str(e))) | |||
| else: | |||
| logging.error(e) | |||
| logger.error(e) | |||
| data = { | |||
| "code": "internal_server_error", | |||
| "message": "Internal Server Error, please contact support.", | |||
| @@ -103,18 +103,23 @@ class BaseAppGenerator: | |||
| f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string" | |||
| ) | |||
| if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str): | |||
| # handle empty string case | |||
| if not value.strip(): | |||
| return None | |||
| # may raise ValueError if user_input_value is not a valid number | |||
| try: | |||
| if "." in value: | |||
| return float(value) | |||
| else: | |||
| return int(value) | |||
| except ValueError: | |||
| raise ValueError(f"{variable_entity.variable} in input form must be a valid number") | |||
| if variable_entity.type == VariableEntityType.NUMBER: | |||
| if isinstance(value, (int, float)): | |||
| return value | |||
| elif isinstance(value, str): | |||
| # handle empty string case | |||
| if not value.strip(): | |||
| return None | |||
| # may raise ValueError if user_input_value is not a valid number | |||
| try: | |||
| if "." in value: | |||
| return float(value) | |||
| else: | |||
| return int(value) | |||
| except ValueError: | |||
| raise ValueError(f"{variable_entity.variable} in input form must be a valid number") | |||
| else: | |||
| raise TypeError(f"expected value type int, float or str, got {type(value)}, value: {value}") | |||
| match variable_entity.type: | |||
| case VariableEntityType.SELECT: | |||
| @@ -144,6 +149,11 @@ class BaseAppGenerator: | |||
| raise ValueError( | |||
| f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" | |||
| ) | |||
| case VariableEntityType.CHECKBOX: | |||
| if not isinstance(value, bool): | |||
| raise ValueError(f"{variable_entity.variable} in input form must be a valid boolean value") | |||
| case _: | |||
| raise AssertionError("this statement should be unreachable.") | |||
| return value | |||
| @@ -32,6 +32,8 @@ from extensions.ext_database import db | |||
| from models.model import AppMode, Conversation, MessageAnnotation, MessageFile | |||
| from services.annotation_service import AppAnnotationService | |||
| logger = logging.getLogger(__name__) | |||
| class MessageCycleManager: | |||
| def __init__( | |||
| @@ -98,7 +100,7 @@ class MessageCycleManager: | |||
| conversation.name = name | |||
| except Exception as e: | |||
| if dify_config.DEBUG: | |||
| logging.exception("generate conversation name failed, conversation_id: %s", conversation_id) | |||
| logger.exception("generate conversation name failed, conversation_id: %s", conversation_id) | |||
| pass | |||
| db.session.merge(conversation) | |||
| @@ -19,6 +19,7 @@ class ModelStatus(Enum): | |||
| QUOTA_EXCEEDED = "quota-exceeded" | |||
| NO_PERMISSION = "no-permission" | |||
| DISABLED = "disabled" | |||
| CREDENTIAL_REMOVED = "credential-removed" | |||
| class SimpleModelProviderEntity(BaseModel): | |||
| @@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel): | |||
| status: ModelStatus | |||
| load_balancing_enabled: bool = False | |||
| has_invalid_load_balancing_configs: bool = False | |||
| def raise_for_status(self) -> None: | |||
| """ | |||
| @@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel): | |||
| restrict_models: list[RestrictModel] = [] | |||
| class CredentialConfiguration(BaseModel): | |||
| """ | |||
| Model class for credential configuration. | |||
| """ | |||
| credential_id: str | |||
| credential_name: str | |||
| class SystemConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider system configuration. | |||
| @@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel): | |||
| """ | |||
| credentials: dict | |||
| current_credential_id: Optional[str] = None | |||
| current_credential_name: Optional[str] = None | |||
| available_credentials: list[CredentialConfiguration] = [] | |||
| class CustomModelConfiguration(BaseModel): | |||
| @@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel): | |||
| model: str | |||
| model_type: ModelType | |||
| credentials: dict | |||
| credentials: dict | None | |||
| current_credential_id: Optional[str] = None | |||
| current_credential_name: Optional[str] = None | |||
| available_model_credentials: list[CredentialConfiguration] = [] | |||
| # pydantic configs | |||
| model_config = ConfigDict(protected_namespaces=()) | |||
| @@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel): | |||
| id: str | |||
| name: str | |||
| credentials: dict | |||
| credential_source_type: str | None = None | |||
| class ModelSettings(BaseModel): | |||
| @@ -10,6 +10,8 @@ from pydantic import BaseModel | |||
| from core.helper.position_helper import sort_to_dict_by_position_map | |||
| logger = logging.getLogger(__name__) | |||
| class ExtensionModule(enum.Enum): | |||
| MODERATION = "moderation" | |||
| @@ -66,7 +68,7 @@ class Extensible: | |||
| # Check for extension module file | |||
| if (extension_name + ".py") not in file_names: | |||
| logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path) | |||
| logger.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path) | |||
| continue | |||
| # Check for builtin flag and position | |||
| @@ -95,7 +97,7 @@ class Extensible: | |||
| break | |||
| if not extension_class: | |||
| logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name) | |||
| logger.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name) | |||
| continue | |||
| # Load schema if not builtin | |||
| @@ -103,7 +105,7 @@ class Extensible: | |||
| if not builtin: | |||
| json_path = os.path.join(subdir_path, "schema.json") | |||
| if not os.path.exists(json_path): | |||
| logging.warning("Missing schema.json file in %s, Skip.", subdir_path) | |||
| logger.warning("Missing schema.json file in %s, Skip.", subdir_path) | |||
| continue | |||
| with open(json_path, encoding="utf-8") as f: | |||
| @@ -122,7 +124,7 @@ class Extensible: | |||
| ) | |||
| except Exception as e: | |||
| logging.exception("Error scanning extensions") | |||
| logger.exception("Error scanning extensions") | |||
| raise | |||
| # Sort extensions by position | |||
| @@ -17,6 +17,7 @@ def encrypt_token(tenant_id: str, token: str): | |||
| if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): | |||
| raise ValueError(f"Tenant with id {tenant_id} not found") | |||
| assert tenant.encrypt_public_key is not None | |||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||
| return base64.b64encode(encrypted_token).decode() | |||
| @@ -4,6 +4,8 @@ import sys | |||
| from types import ModuleType | |||
| from typing import AnyStr | |||
| logger = logging.getLogger(__name__) | |||
| def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType: | |||
| """ | |||
| @@ -30,7 +32,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz | |||
| spec.loader.exec_module(module) | |||
| return module | |||
| except Exception as e: | |||
| logging.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path)) | |||
| logger.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path)) | |||
| raise e | |||
| @@ -9,6 +9,8 @@ import httpx | |||
| from configs import dify_config | |||
| logger = logging.getLogger(__name__) | |||
| SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES | |||
| HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True | |||
| @@ -73,12 +75,12 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): | |||
| if response.status_code not in STATUS_FORCELIST: | |||
| return response | |||
| else: | |||
| logging.warning( | |||
| logger.warning( | |||
| "Received status code %s for URL %s which is in the force list", response.status_code, url | |||
| ) | |||
| except httpx.RequestError as e: | |||
| logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e) | |||
| logger.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e) | |||
| if max_retries == 0: | |||
| raise | |||
| @@ -39,6 +39,8 @@ from models.dataset import Document as DatasetDocument | |||
| from models.model import UploadFile | |||
| from services.feature_service import FeatureService | |||
| logger = logging.getLogger(__name__) | |||
| class IndexingRunner: | |||
| def __init__(self): | |||
| @@ -90,9 +92,9 @@ class IndexingRunner: | |||
| dataset_document.stopped_at = naive_utc_now() | |||
| db.session.commit() | |||
| except ObjectDeletedError: | |||
| logging.warning("Document deleted, document id: %s", dataset_document.id) | |||
| logger.warning("Document deleted, document id: %s", dataset_document.id) | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| logger.exception("consume document failed") | |||
| dataset_document.indexing_status = "error" | |||
| dataset_document.error = str(e) | |||
| dataset_document.stopped_at = naive_utc_now() | |||
| @@ -153,7 +155,7 @@ class IndexingRunner: | |||
| dataset_document.stopped_at = naive_utc_now() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| logger.exception("consume document failed") | |||
| dataset_document.indexing_status = "error" | |||
| dataset_document.error = str(e) | |||
| dataset_document.stopped_at = naive_utc_now() | |||
| @@ -228,7 +230,7 @@ class IndexingRunner: | |||
| dataset_document.stopped_at = naive_utc_now() | |||
| db.session.commit() | |||
| except Exception as e: | |||
| logging.exception("consume document failed") | |||
| logger.exception("consume document failed") | |||
| dataset_document.indexing_status = "error" | |||
| dataset_document.error = str(e) | |||
| dataset_document.stopped_at = naive_utc_now() | |||
| @@ -321,7 +323,7 @@ class IndexingRunner: | |||
| try: | |||
| storage.delete(image_file.key) | |||
| except Exception: | |||
| logging.exception( | |||
| logger.exception( | |||
| "Delete image_files failed while indexing_estimate, \ | |||
| image_upload_file_is: %s", | |||
| upload_file_id, | |||
| @@ -31,6 +31,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution | |||
| from core.workflow.graph_engine.entities.event import AgentLogEvent | |||
| from models import App, Message, WorkflowNodeExecutionModel, db | |||
| logger = logging.getLogger(__name__) | |||
| class LLMGenerator: | |||
| @classmethod | |||
| @@ -68,7 +70,7 @@ class LLMGenerator: | |||
| result_dict = json.loads(cleaned_answer) | |||
| answer = result_dict["Your Output"] | |||
| except json.JSONDecodeError as e: | |||
| logging.exception("Failed to generate name after answer, use query instead") | |||
| logger.exception("Failed to generate name after answer, use query instead") | |||
| answer = query | |||
| name = answer.strip() | |||
| @@ -125,7 +127,7 @@ class LLMGenerator: | |||
| except InvokeError: | |||
| questions = [] | |||
| except Exception: | |||
| logging.exception("Failed to generate suggested questions after answer") | |||
| logger.exception("Failed to generate suggested questions after answer") | |||
| questions = [] | |||
| return questions | |||
| @@ -173,7 +175,7 @@ class LLMGenerator: | |||
| error = str(e) | |||
| error_step = "generate rule config" | |||
| except Exception as e: | |||
| logging.exception("Failed to generate rule config, model: %s", model_config.get("name")) | |||
| logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) | |||
| rule_config["error"] = str(e) | |||
| rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" | |||
| @@ -270,7 +272,7 @@ class LLMGenerator: | |||
| error_step = "generate conversation opener" | |||
| except Exception as e: | |||
| logging.exception("Failed to generate rule config, model: %s", model_config.get("name")) | |||
| logger.exception("Failed to generate rule config, model: %s", model_config.get("name")) | |||
| rule_config["error"] = str(e) | |||
| rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" | |||
| @@ -319,7 +321,7 @@ class LLMGenerator: | |||
| error = str(e) | |||
| return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"} | |||
| except Exception as e: | |||
| logging.exception( | |||
| logger.exception( | |||
| "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language | |||
| ) | |||
| return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"} | |||
| @@ -392,7 +394,7 @@ class LLMGenerator: | |||
| error = str(e) | |||
| return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} | |||
| except Exception as e: | |||
| logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) | |||
| logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name")) | |||
| return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} | |||
| @staticmethod | |||
| @@ -570,5 +572,5 @@ class LLMGenerator: | |||
| error = str(e) | |||
| return {"error": f"Failed to generate code. Error: {error}"} | |||
| except Exception as e: | |||
| logging.exception("Failed to invoke LLM model, model: " + json.dumps(model_config.get("name")), exc_info=e) | |||
| logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=e) | |||
| return {"error": f"An unexpected error occurred: {str(e)}"} | |||
| @@ -152,7 +152,7 @@ class MCPClient: | |||
| # ExitStack will handle proper cleanup of all managed context managers | |||
| self._exit_stack.close() | |||
| except Exception as e: | |||
| logging.exception("Error during cleanup") | |||
| logger.exception("Error during cleanup") | |||
| raise ValueError(f"Error during cleanup: {e}") | |||
| finally: | |||
| self._session = None | |||
| @@ -31,6 +31,9 @@ from core.mcp.types import ( | |||
| SessionMessage, | |||
| ) | |||
| logger = logging.getLogger(__name__) | |||
| SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) | |||
| SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) | |||
| SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) | |||
| @@ -366,7 +369,7 @@ class BaseSession( | |||
| self._handle_incoming(notification) | |||
| except Exception as e: | |||
| # For other validation errors, log and continue | |||
| logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root) | |||
| logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root) | |||
| else: # Response or error | |||
| response_queue = self._response_streams.get(message.message.root.id) | |||
| if response_queue is not None: | |||
| @@ -376,7 +379,7 @@ class BaseSession( | |||
| except queue.Empty: | |||
| continue | |||
| except Exception: | |||
| logging.exception("Error in message processing loop") | |||
| logger.exception("Error in message processing loop") | |||
| raise | |||
| def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: | |||
| @@ -201,7 +201,7 @@ class ModelProviderFactory: | |||
| return filtered_credentials | |||
| def get_model_schema( | |||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | |||
| self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None | |||
| ) -> AIModelEntity | None: | |||
| """ | |||
| Get model schema | |||
| @@ -100,14 +100,14 @@ class Moderation(Extensible, ABC): | |||
| if not inputs_config.get("preset_response"): | |||
| raise ValueError("inputs_config.preset_response is required") | |||
| if len(inputs_config.get("preset_response", 0)) > 100: | |||
| if len(inputs_config.get("preset_response", "0")) > 100: | |||
| raise ValueError("inputs_config.preset_response must be less than 100 characters") | |||
| if outputs_config_enabled: | |||
| if not outputs_config.get("preset_response"): | |||
| raise ValueError("outputs_config.preset_response is required") | |||
| if len(outputs_config.get("preset_response", 0)) > 100: | |||
| if len(outputs_config.get("preset_response", "0")) > 100: | |||
| raise ValueError("outputs_config.preset_response must be less than 100 characters") | |||
| @@ -306,7 +306,7 @@ class AliyunDataTrace(BaseTraceInstance): | |||
| node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) | |||
| return node_span | |||
| except Exception as e: | |||
| logging.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True) | |||
| logger.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True) | |||
| return None | |||
| def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: | |||
| @@ -37,6 +37,8 @@ from models.model import App, AppModelConfig, Conversation, Message, MessageFile | |||
| from models.workflow import WorkflowAppLog, WorkflowRun | |||
| from tasks.ops_trace_task import process_trace_tasks | |||
| logger = logging.getLogger(__name__) | |||
| class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): | |||
| def __getitem__(self, provider: str) -> dict[str, Any]: | |||
| @@ -287,7 +289,7 @@ class OpsTraceManager: | |||
| # create new tracing_instance and update the cache if it absent | |||
| tracing_instance = trace_instance(config_class(**decrypt_trace_config)) | |||
| cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance | |||
| logging.info("new tracing_instance for app_id: %s", app_id) | |||
| logger.info("new tracing_instance for app_id: %s", app_id) | |||
| return tracing_instance | |||
| @classmethod | |||
| @@ -849,7 +851,7 @@ class TraceQueueManager: | |||
| trace_task.app_id = self.app_id | |||
| trace_manager_queue.put(trace_task) | |||
| except Exception as e: | |||
| logging.exception("Error adding trace task, trace_type %s", trace_task.trace_type) | |||
| logger.exception("Error adding trace task, trace_type %s", trace_task.trace_type) | |||
| finally: | |||
| self.start_timer() | |||
| @@ -868,7 +870,7 @@ class TraceQueueManager: | |||
| if tasks: | |||
| self.send_to_celery(tasks) | |||
| except Exception as e: | |||
| logging.exception("Error processing trace tasks") | |||
| logger.exception("Error processing trace tasks") | |||
| def start_timer(self): | |||
| global trace_manager_timer | |||
| @@ -154,7 +154,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): | |||
| """ | |||
| workflow = app.workflow | |||
| if not workflow: | |||
| raise ValueError("") | |||
| raise ValueError("unexpected app type") | |||
| return WorkflowAppGenerator().generate( | |||
| app_model=app, | |||
| @@ -8,6 +8,7 @@ from core.plugin.entities.plugin_daemon import ( | |||
| ) | |||
| from core.plugin.entities.request import PluginInvokeContext | |||
| from core.plugin.impl.base import BasePluginClient | |||
| from core.plugin.utils.chunk_merger import merge_blob_chunks | |||
| class PluginAgentClient(BasePluginClient): | |||
| @@ -113,4 +114,4 @@ class PluginAgentClient(BasePluginClient): | |||
| "Content-Type": "application/json", | |||
| }, | |||
| ) | |||
| return response | |||
| return merge_blob_chunks(response) | |||
| @@ -141,11 +141,11 @@ class BasePluginClient: | |||
| response.raise_for_status() | |||
| except HTTPError as e: | |||
| msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}" | |||
| logging.exception(msg) | |||
| logger.exception(msg) | |||
| raise e | |||
| except Exception as e: | |||
| msg = f"Failed to request plugin daemon, url: {path}" | |||
| logging.exception(msg) | |||
| logger.exception(msg) | |||
| raise ValueError(msg) from e | |||
| try: | |||
| @@ -158,7 +158,7 @@ class BasePluginClient: | |||
| f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}]," | |||
| f" url: {path}" | |||
| ) | |||
| logging.exception(msg) | |||
| logger.exception(msg) | |||
| raise ValueError(msg) | |||
| if rep.code != 0: | |||
| @@ -9,6 +9,7 @@ from core.plugin.entities.plugin_daemon import ( | |||
| PluginToolProviderEntity, | |||
| ) | |||
| from core.plugin.impl.base import BasePluginClient | |||
| from core.plugin.utils.chunk_merger import merge_blob_chunks | |||
| from core.schemas.resolver import resolve_dify_schema_refs | |||
| from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter | |||
| @@ -123,61 +124,7 @@ class PluginToolManager(BasePluginClient): | |||
| }, | |||
| ) | |||
| class FileChunk: | |||
| """ | |||
| Only used for internal processing. | |||
| """ | |||
| bytes_written: int | |||
| total_length: int | |||
| data: bytearray | |||
| def __init__(self, total_length: int): | |||
| self.bytes_written = 0 | |||
| self.total_length = total_length | |||
| self.data = bytearray(total_length) | |||
| files: dict[str, FileChunk] = {} | |||
| for resp in response: | |||
| if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: | |||
| assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage) | |||
| # Get blob chunk information | |||
| chunk_id = resp.message.id | |||
| total_length = resp.message.total_length | |||
| blob_data = resp.message.blob | |||
| is_end = resp.message.end | |||
| # Initialize buffer for this file if it doesn't exist | |||
| if chunk_id not in files: | |||
| files[chunk_id] = FileChunk(total_length) | |||
| # If this is the final chunk, yield a complete blob message | |||
| if is_end: | |||
| yield ToolInvokeMessage( | |||
| type=ToolInvokeMessage.MessageType.BLOB, | |||
| message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data), | |||
| meta=resp.meta, | |||
| ) | |||
| else: | |||
| # Check if file is too large (30MB limit) | |||
| if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024: | |||
| # Delete the file if it's too large | |||
| del files[chunk_id] | |||
| # Skip yielding this message | |||
| raise ValueError("File is too large which reached the limit of 30MB") | |||
| # Check if single chunk is too large (8KB limit) | |||
| if len(blob_data) > 8192: | |||
| # Skip yielding this message | |||
| raise ValueError("File chunk is too large which reached the limit of 8KB") | |||
| # Append the blob data to the buffer | |||
| files[chunk_id].data[ | |||
| files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data) | |||
| ] = blob_data | |||
| files[chunk_id].bytes_written += len(blob_data) | |||
| else: | |||
| yield resp | |||
| return merge_blob_chunks(response) | |||
| def validate_provider_credentials( | |||
| self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] | |||
| @@ -0,0 +1,92 @@ | |||
| from collections.abc import Generator | |||
| from dataclasses import dataclass, field | |||
| from typing import TypeVar, Union, cast | |||
| from core.agent.entities import AgentInvokeMessage | |||
| from core.tools.entities.tool_entities import ToolInvokeMessage | |||
| MessageType = TypeVar("MessageType", bound=Union[ToolInvokeMessage, AgentInvokeMessage]) | |||
| @dataclass | |||
| class FileChunk: | |||
| """ | |||
| Buffer for accumulating file chunks during streaming. | |||
| """ | |||
| total_length: int | |||
| bytes_written: int = field(default=0, init=False) | |||
| data: bytearray = field(init=False) | |||
| def __post_init__(self) -> None: | |||
| self.data = bytearray(self.total_length) | |||
| def merge_blob_chunks( | |||
| response: Generator[MessageType, None, None], | |||
| max_file_size: int = 30 * 1024 * 1024, | |||
| max_chunk_size: int = 8192, | |||
| ) -> Generator[MessageType, None, None]: | |||
| """ | |||
| Merge streaming blob chunks into complete blob messages. | |||
| This function processes a stream of plugin invoke messages, accumulating | |||
| BLOB_CHUNK messages by their ID until the final chunk is received, | |||
| then yielding a single complete BLOB message. | |||
| Args: | |||
| response: Generator yielding messages that may include blob chunks | |||
| max_file_size: Maximum allowed file size in bytes (default: 30MB) | |||
| max_chunk_size: Maximum allowed chunk size in bytes (default: 8KB) | |||
| Yields: | |||
| Messages from the response stream, with blob chunks merged into complete blobs | |||
| Raises: | |||
| ValueError: If file size exceeds max_file_size or chunk size exceeds max_chunk_size | |||
| """ | |||
| files: dict[str, FileChunk] = {} | |||
| for resp in response: | |||
| if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: | |||
| assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage) | |||
| # Get blob chunk information | |||
| chunk_id = resp.message.id | |||
| total_length = resp.message.total_length | |||
| blob_data = resp.message.blob | |||
| is_end = resp.message.end | |||
| # Initialize buffer for this file if it doesn't exist | |||
| if chunk_id not in files: | |||
| files[chunk_id] = FileChunk(total_length) | |||
| # Check if file is too large (before appending) | |||
| if files[chunk_id].bytes_written + len(blob_data) > max_file_size: | |||
| # Delete the file if it's too large | |||
| del files[chunk_id] | |||
| raise ValueError(f"File is too large which reached the limit of {max_file_size / 1024 / 1024}MB") | |||
| # Check if single chunk is too large | |||
| if len(blob_data) > max_chunk_size: | |||
| raise ValueError(f"File chunk is too large which reached the limit of {max_chunk_size / 1024}KB") | |||
| # Append the blob data to the buffer | |||
| files[chunk_id].data[files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)] = ( | |||
| blob_data | |||
| ) | |||
| files[chunk_id].bytes_written += len(blob_data) | |||
| # If this is the final chunk, yield a complete blob message | |||
| if is_end: | |||
| # Create the appropriate message type based on the response type | |||
| message_class = type(resp) | |||
| merged_message = message_class( | |||
| type=ToolInvokeMessage.MessageType.BLOB, | |||
| message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data[: files[chunk_id].bytes_written]), | |||
| meta=resp.meta, | |||
| ) | |||
| yield cast(MessageType, merged_message) | |||
| # Clean up the buffer | |||
| del files[chunk_id] | |||
| else: | |||
| yield resp | |||
| @@ -12,6 +12,7 @@ from configs import dify_config | |||
| from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | |||
| from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | |||
| from core.entities.provider_entities import ( | |||
| CredentialConfiguration, | |||
| CustomConfiguration, | |||
| CustomModelConfiguration, | |||
| CustomProviderConfiguration, | |||
| @@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client | |||
| from models.provider import ( | |||
| LoadBalancingModelConfig, | |||
| Provider, | |||
| ProviderCredential, | |||
| ProviderModel, | |||
| ProviderModelCredential, | |||
| ProviderModelSetting, | |||
| ProviderType, | |||
| TenantDefaultModel, | |||
| @@ -488,6 +491,61 @@ class ProviderManager: | |||
| return provider_name_to_provider_load_balancing_model_configs_dict | |||
| @staticmethod | |||
| def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: | |||
| """ | |||
| Get provider all credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider_name: provider name | |||
| :return: | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = ( | |||
| select(ProviderCredential) | |||
| .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) | |||
| .order_by(ProviderCredential.created_at.desc()) | |||
| ) | |||
| available_credentials = session.scalars(stmt).all() | |||
| return [ | |||
| CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) | |||
| for credential in available_credentials | |||
| ] | |||
| @staticmethod | |||
| def get_provider_model_available_credentials( | |||
| tenant_id: str, provider_name: str, model_name: str, model_type: str | |||
| ) -> list[CredentialConfiguration]: | |||
| """ | |||
| Get provider custom model all credentials. | |||
| :param tenant_id: workspace id | |||
| :param provider_name: provider name | |||
| :param model_name: model name | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| with Session(db.engine, expire_on_commit=False) as session: | |||
| stmt = ( | |||
| select(ProviderModelCredential) | |||
| .where( | |||
| ProviderModelCredential.tenant_id == tenant_id, | |||
| ProviderModelCredential.provider_name == provider_name, | |||
| ProviderModelCredential.model_name == model_name, | |||
| ProviderModelCredential.model_type == model_type, | |||
| ) | |||
| .order_by(ProviderModelCredential.created_at.desc()) | |||
| ) | |||
| available_credentials = session.scalars(stmt).all() | |||
| return [ | |||
| CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name) | |||
| for credential in available_credentials | |||
| ] | |||
| @staticmethod | |||
| def _init_trial_provider_records( | |||
| tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] | |||
| @@ -590,9 +648,6 @@ class ProviderManager: | |||
| if provider_record.provider_type == ProviderType.SYSTEM.value: | |||
| continue | |||
| if not provider_record.encrypted_config: | |||
| continue | |||
| custom_provider_record = provider_record | |||
| # Get custom provider credentials | |||
| @@ -611,8 +666,8 @@ class ProviderManager: | |||
| try: | |||
| # fix origin data | |||
| if custom_provider_record.encrypted_config is None: | |||
| raise ValueError("No credentials found") | |||
| if not custom_provider_record.encrypted_config.startswith("{"): | |||
| provider_credentials = {} | |||
| elif not custom_provider_record.encrypted_config.startswith("{"): | |||
| provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | |||
| else: | |||
| provider_credentials = json.loads(custom_provider_record.encrypted_config) | |||
| @@ -637,7 +692,14 @@ class ProviderManager: | |||
| else: | |||
| provider_credentials = cached_provider_credentials | |||
| custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) | |||
| custom_provider_configuration = CustomProviderConfiguration( | |||
| credentials=provider_credentials, | |||
| current_credential_name=custom_provider_record.credential_name, | |||
| current_credential_id=custom_provider_record.credential_id, | |||
| available_credentials=self.get_provider_available_credentials( | |||
| tenant_id, custom_provider_record.provider_name | |||
| ), | |||
| ) | |||
| # Get provider model credential secret variables | |||
| model_credential_secret_variables = self._extract_secret_variables( | |||
| @@ -649,8 +711,12 @@ class ProviderManager: | |||
| # Get custom provider model credentials | |||
| custom_model_configurations = [] | |||
| for provider_model_record in provider_model_records: | |||
| if not provider_model_record.encrypted_config: | |||
| continue | |||
| available_model_credentials = self.get_provider_model_available_credentials( | |||
| tenant_id, | |||
| provider_model_record.provider_name, | |||
| provider_model_record.model_name, | |||
| provider_model_record.model_type, | |||
| ) | |||
| provider_model_credentials_cache = ProviderCredentialsCache( | |||
| tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | |||
| @@ -659,7 +725,7 @@ class ProviderManager: | |||
| # Get cached provider model credentials | |||
| cached_provider_model_credentials = provider_model_credentials_cache.get() | |||
| if not cached_provider_model_credentials: | |||
| if not cached_provider_model_credentials and provider_model_record.encrypted_config: | |||
| try: | |||
| provider_model_credentials = json.loads(provider_model_record.encrypted_config) | |||
| except JSONDecodeError: | |||
| @@ -688,6 +754,9 @@ class ProviderManager: | |||
| model=provider_model_record.model_name, | |||
| model_type=ModelType.value_of(provider_model_record.model_type), | |||
| credentials=provider_model_credentials, | |||
| current_credential_id=provider_model_record.credential_id, | |||
| current_credential_name=provider_model_record.credential_name, | |||
| available_model_credentials=available_model_credentials, | |||
| ) | |||
| ) | |||
| @@ -899,6 +968,18 @@ class ProviderManager: | |||
| load_balancing_model_config.model_name == provider_model_setting.model_name | |||
| and load_balancing_model_config.model_type == provider_model_setting.model_type | |||
| ): | |||
| if load_balancing_model_config.name == "__delete__": | |||
| # to calculate current model whether has invalidate lb configs | |||
| load_balancing_configs.append( | |||
| ModelLoadBalancingConfiguration( | |||
| id=load_balancing_model_config.id, | |||
| name=load_balancing_model_config.name, | |||
| credentials={}, | |||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||
| ) | |||
| ) | |||
| continue | |||
| if not load_balancing_model_config.enabled: | |||
| continue | |||
| @@ -955,6 +1036,7 @@ class ProviderManager: | |||
| id=load_balancing_model_config.id, | |||
| name=load_balancing_model_config.name, | |||
| credentials=provider_model_credentials, | |||
| credential_source_type=load_balancing_model_config.credential_source_type, | |||
| ) | |||
| ) | |||
| @@ -259,8 +259,16 @@ class MilvusVector(BaseVector): | |||
| """ | |||
| Search for documents by full-text search (if hybrid search is enabled). | |||
| """ | |||
| if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value): | |||
| logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)") | |||
| if not self._hybrid_search_enabled: | |||
| logger.warning( | |||
| "Full-text search is disabled: set MILVUS_ENABLE_HYBRID_SEARCH=true (requires Milvus >= 2.5.0)." | |||
| ) | |||
| return [] | |||
| if not self.field_exists(Field.SPARSE_VECTOR.value): | |||
| logger.warning( | |||
| "Full-text search unavailable: collection missing 'sparse_vector' field; " | |||
| "recreate the collection after enabling MILVUS_ENABLE_HYBRID_SEARCH to add BM25 sparse index." | |||
| ) | |||
| return [] | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| filter = "" | |||
| @@ -15,6 +15,8 @@ from core.rag.embedding.embedding_base import Embeddings | |||
| from core.rag.models.document import Document | |||
| from models.dataset import Dataset | |||
| logger = logging.getLogger(__name__) | |||
| class MyScaleConfig(BaseModel): | |||
| host: str | |||
| @@ -53,7 +55,7 @@ class MyScaleVector(BaseVector): | |||
| return self.add_texts(documents=texts, embeddings=embeddings, **kwargs) | |||
| def _create_collection(self, dimension: int): | |||
| logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension) | |||
| logger.info("create MyScale collection %s with dimension %s", self._collection_name, dimension) | |||
| self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}") | |||
| fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else "" | |||
| sql = f""" | |||
| @@ -151,7 +153,7 @@ class MyScaleVector(BaseVector): | |||
| for r in self._client.query(sql).named_results() | |||
| ] | |||
| except Exception as e: | |||
| logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401 | |||
| logger.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401 | |||
| return [] | |||
| def delete(self) -> None: | |||
| @@ -188,14 +188,17 @@ class OracleVector(BaseVector): | |||
| def text_exists(self, id: str) -> bool: | |||
| with self._get_connection() as conn: | |||
| with conn.cursor() as cur: | |||
| cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) | |||
| cur.execute(f"SELECT id FROM {self.table_name} WHERE id = :1", (id,)) | |||
| return cur.fetchone() is not None | |||
| conn.close() | |||
| def get_by_ids(self, ids: list[str]) -> list[Document]: | |||
| if not ids: | |||
| return [] | |||
| with self._get_connection() as conn: | |||
| with conn.cursor() as cur: | |||
| cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | |||
| placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) | |||
| cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids) | |||
| docs = [] | |||
| for record in cur: | |||
| docs.append(Document(page_content=record[1], metadata=record[0])) | |||
| @@ -208,14 +211,15 @@ class OracleVector(BaseVector): | |||
| return | |||
| with self._get_connection() as conn: | |||
| with conn.cursor() as cur: | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) | |||
| placeholders = ", ".join(f":{i + 1}" for i in range(len(ids))) | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN ({placeholders})", ids) | |||
| conn.commit() | |||
| conn.close() | |||
| def delete_by_metadata_field(self, key: str, value: str) -> None: | |||
| with self._get_connection() as conn: | |||
| with conn.cursor() as cur: | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE JSON_VALUE(meta, '$." + key + "') = :1", (value,)) | |||
| conn.commit() | |||
| conn.close() | |||
| @@ -227,12 +231,20 @@ class OracleVector(BaseVector): | |||
| :param top_k: The number of nearest neighbors to return, default is 5. | |||
| :return: List of Documents that are nearest to the query vector. | |||
| """ | |||
| # Validate and sanitize top_k to prevent SQL injection | |||
| top_k = kwargs.get("top_k", 4) | |||
| if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: | |||
| top_k = 4 # Use default if invalid | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| params = [numpy.array(query_vector)] | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" | |||
| placeholders = ", ".join(f":{i + 2}" for i in range(len(document_ids_filter))) | |||
| where_clause = f"WHERE JSON_VALUE(meta, '$.document_id') IN ({placeholders})" | |||
| params.extend(document_ids_filter) | |||
| with self._get_connection() as conn: | |||
| conn.inputtypehandler = self.input_type_handler | |||
| conn.outputtypehandler = self.output_type_handler | |||
| @@ -241,7 +253,7 @@ class OracleVector(BaseVector): | |||
| f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) | |||
| AS distance FROM {self.table_name} | |||
| {where_clause} ORDER BY distance fetch first {top_k} rows only""", | |||
| [numpy.array(query_vector)], | |||
| params, | |||
| ) | |||
| docs = [] | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| @@ -259,7 +271,10 @@ class OracleVector(BaseVector): | |||
| import nltk # type: ignore | |||
| from nltk.corpus import stopwords # type: ignore | |||
| # Validate and sanitize top_k to prevent SQL injection | |||
| top_k = kwargs.get("top_k", 5) | |||
| if not isinstance(top_k, int) or top_k <= 0 or top_k > 10000: | |||
| top_k = 5 # Use default if invalid | |||
| # just not implement fetch by score_threshold now, may be later | |||
| score_threshold = float(kwargs.get("score_threshold") or 0.0) | |||
| if len(query) > 0: | |||
| @@ -297,14 +312,21 @@ class OracleVector(BaseVector): | |||
| with conn.cursor() as cur: | |||
| document_ids_filter = kwargs.get("document_ids_filter") | |||
| where_clause = "" | |||
| params: dict[str, Any] = {"kk": " ACCUM ".join(entities)} | |||
| if document_ids_filter: | |||
| document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) | |||
| where_clause = f" AND metadata->>'document_id' in ({document_ids}) " | |||
| placeholders = [] | |||
| for i, doc_id in enumerate(document_ids_filter): | |||
| param_name = f"doc_id_{i}" | |||
| placeholders.append(f":{param_name}") | |||
| params[param_name] = doc_id | |||
| where_clause = f" AND JSON_VALUE(meta, '$.document_id') IN ({', '.join(placeholders)}) " | |||
| cur.execute( | |||
| f"""select meta, text, embedding FROM {self.table_name} | |||
| WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} | |||
| order by score(1) desc fetch first {top_k} rows only""", | |||
| kk=" ACCUM ".join(entities), | |||
| params, | |||
| ) | |||
| docs = [] | |||
| for record in cur: | |||
| @@ -19,6 +19,8 @@ from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset | |||
| logger = logging.getLogger(__name__) | |||
| class PGVectorConfig(BaseModel): | |||
| host: str | |||
| @@ -155,7 +157,7 @@ class PGVector(BaseVector): | |||
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | |||
| except psycopg2.errors.UndefinedTable: | |||
| # table not exists | |||
| logging.warning("Table %s not found, skipping delete operation.", self.table_name) | |||
| logger.warning("Table %s not found, skipping delete operation.", self.table_name) | |||
| return | |||
| except Exception as e: | |||
| raise e | |||
| @@ -17,6 +17,8 @@ from core.rag.models.document import Document | |||
| from extensions.ext_redis import redis_client | |||
| from models import Dataset | |||
| logger = logging.getLogger(__name__) | |||
| class TableStoreConfig(BaseModel): | |||
| access_key_id: Optional[str] = None | |||
| @@ -145,7 +147,7 @@ class TableStoreVector(BaseVector): | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||
| if redis_client.get(collection_exist_cache_key): | |||
| logging.info("Collection %s already exists.", self._collection_name) | |||
| logger.info("Collection %s already exists.", self._collection_name) | |||
| return | |||
| self._create_table_if_not_exist() | |||
| @@ -155,7 +157,7 @@ class TableStoreVector(BaseVector): | |||
| def _create_table_if_not_exist(self) -> None: | |||
| table_list = self._tablestore_client.list_table() | |||
| if self._table_name in table_list: | |||
| logging.info("Tablestore system table[%s] already exists", self._table_name) | |||
| logger.info("Tablestore system table[%s] already exists", self._table_name) | |||
| return None | |||
| schema_of_primary_key = [("id", "STRING")] | |||
| @@ -163,12 +165,12 @@ class TableStoreVector(BaseVector): | |||
| table_options = tablestore.TableOptions() | |||
| reserved_throughput = tablestore.ReservedThroughput(tablestore.CapacityUnit(0, 0)) | |||
| self._tablestore_client.create_table(table_meta, table_options, reserved_throughput) | |||
| logging.info("Tablestore create table[%s] successfully.", self._table_name) | |||
| logger.info("Tablestore create table[%s] successfully.", self._table_name) | |||
| def _create_search_index_if_not_exist(self, dimension: int) -> None: | |||
| search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name) | |||
| if self._index_name in [t[1] for t in search_index_list]: | |||
| logging.info("Tablestore system index[%s] already exists", self._index_name) | |||
| logger.info("Tablestore system index[%s] already exists", self._index_name) | |||
| return None | |||
| field_schemas = [ | |||
| @@ -206,20 +208,20 @@ class TableStoreVector(BaseVector): | |||
| index_meta = tablestore.SearchIndexMeta(field_schemas) | |||
| self._tablestore_client.create_search_index(self._table_name, self._index_name, index_meta) | |||
| logging.info("Tablestore create system index[%s] successfully.", self._index_name) | |||
| logger.info("Tablestore create system index[%s] successfully.", self._index_name) | |||
| def _delete_table_if_exist(self): | |||
| search_index_list = self._tablestore_client.list_search_index(table_name=self._table_name) | |||
| for resp_tuple in search_index_list: | |||
| self._tablestore_client.delete_search_index(resp_tuple[0], resp_tuple[1]) | |||
| logging.info("Tablestore delete index[%s] successfully.", self._index_name) | |||
| logger.info("Tablestore delete index[%s] successfully.", self._index_name) | |||
| self._tablestore_client.delete_table(self._table_name) | |||
| logging.info("Tablestore delete system table[%s] successfully.", self._index_name) | |||
| logger.info("Tablestore delete system table[%s] successfully.", self._index_name) | |||
| def _delete_search_index(self) -> None: | |||
| self._tablestore_client.delete_search_index(self._table_name, self._index_name) | |||
| logging.info("Tablestore delete index[%s] successfully.", self._index_name) | |||
| logger.info("Tablestore delete index[%s] successfully.", self._index_name) | |||
| def _write_row(self, primary_key: str, attributes: dict[str, Any]) -> None: | |||
| pk = [("id", primary_key)] | |||
| @@ -83,14 +83,14 @@ class TiDBVector(BaseVector): | |||
| self._dimension = 1536 | |||
| def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): | |||
| logger.info("create collection and add texts, collection_name: " + self._collection_name) | |||
| logger.info("create collection and add texts, collection_name: %s", self._collection_name) | |||
| self._create_collection(len(embeddings[0])) | |||
| self.add_texts(texts, embeddings) | |||
| self._dimension = len(embeddings[0]) | |||
| pass | |||
| def _create_collection(self, dimension: int): | |||
| logger.info("_create_collection, collection_name " + self._collection_name) | |||
| logger.info("_create_collection, collection_name %s", self._collection_name) | |||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||
| with redis_client.lock(lock_name, timeout=20): | |||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||
| @@ -75,7 +75,7 @@ class CacheEmbedding(Embeddings): | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| except Exception: | |||
| logging.exception("Failed transform embedding") | |||
| logger.exception("Failed transform embedding") | |||
| cache_embeddings = [] | |||
| try: | |||
| for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings): | |||
| @@ -95,7 +95,7 @@ class CacheEmbedding(Embeddings): | |||
| db.session.rollback() | |||
| except Exception as ex: | |||
| db.session.rollback() | |||
| logger.exception("Failed to embed documents: %s") | |||
| logger.exception("Failed to embed documents") | |||
| raise ex | |||
| return text_embeddings | |||
| @@ -122,7 +122,7 @@ class CacheEmbedding(Embeddings): | |||
| raise ValueError("Normalized embedding is nan please try again") | |||
| except Exception as ex: | |||
| if dify_config.DEBUG: | |||
| logging.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text)) | |||
| logger.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text)) | |||
| raise ex | |||
| try: | |||
| @@ -136,7 +136,7 @@ class CacheEmbedding(Embeddings): | |||
| redis_client.setex(embedding_cache_key, 600, encoded_str) | |||
| except Exception as ex: | |||
| if dify_config.DEBUG: | |||
| logging.exception( | |||
| logger.exception( | |||
| "Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text) | |||
| ) | |||
| raise ex | |||
| @@ -26,6 +26,8 @@ from models.dataset import Dataset | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||
| logger = logging.getLogger(__name__) | |||
| class QAIndexProcessor(BaseIndexProcessor): | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| @@ -215,7 +217,7 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| qa_documents.append(qa_document) | |||
| format_documents.extend(qa_documents) | |||
| except Exception as e: | |||
| logging.exception("Failed to format qa document") | |||
| logger.exception("Failed to format qa document") | |||
| all_qa_documents.extend(format_documents) | |||
| @@ -39,9 +39,16 @@ class WeightRerankRunner(BaseRerankRunner): | |||
| unique_documents = [] | |||
| doc_ids = set() | |||
| for document in documents: | |||
| if document.metadata is not None and document.metadata["doc_id"] not in doc_ids: | |||
| if ( | |||
| document.provider == "dify" | |||
| and document.metadata is not None | |||
| and document.metadata["doc_id"] not in doc_ids | |||
| ): | |||
| doc_ids.add(document.metadata["doc_id"]) | |||
| unique_documents.append(document) | |||
| else: | |||
| if document not in unique_documents: | |||
| unique_documents.append(document) | |||
| documents = unique_documents | |||
| @@ -275,35 +275,30 @@ class ApiTool(Tool): | |||
| if files: | |||
| headers.pop("Content-Type", None) | |||
| if method in { | |||
| "get", | |||
| "head", | |||
| "post", | |||
| "put", | |||
| "delete", | |||
| "patch", | |||
| "options", | |||
| "GET", | |||
| "POST", | |||
| "PUT", | |||
| "PATCH", | |||
| "DELETE", | |||
| "HEAD", | |||
| "OPTIONS", | |||
| }: | |||
| response: httpx.Response = getattr(ssrf_proxy, method.lower())( | |||
| url, | |||
| params=params, | |||
| headers=headers, | |||
| cookies=cookies, | |||
| data=body, | |||
| files=files, | |||
| timeout=API_TOOL_DEFAULT_TIMEOUT, | |||
| follow_redirects=True, | |||
| ) | |||
| return response | |||
| else: | |||
| _METHOD_MAP = { | |||
| "get": ssrf_proxy.get, | |||
| "head": ssrf_proxy.head, | |||
| "post": ssrf_proxy.post, | |||
| "put": ssrf_proxy.put, | |||
| "delete": ssrf_proxy.delete, | |||
| "patch": ssrf_proxy.patch, | |||
| } | |||
| method_lc = method.lower() | |||
| if method_lc not in _METHOD_MAP: | |||
| raise ValueError(f"Invalid http method {method}") | |||
| response: httpx.Response = _METHOD_MAP[ | |||
| method_lc | |||
| ]( # https://discuss.python.org/t/type-inference-for-function-return-types/42926 | |||
| url, | |||
| params=params, | |||
| headers=headers, | |||
| cookies=cookies, | |||
| data=body, | |||
| files=files, | |||
| timeout=API_TOOL_DEFAULT_TIMEOUT, | |||
| follow_redirects=True, | |||
| ) | |||
| return response | |||
| def _convert_body_property_any_of( | |||
| self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 | |||