Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>

tags/0.4.0
| @@ -0,0 +1,58 @@ | |||
| name: Run Pytest | |||
| on: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| push: | |||
| branches: | |||
| - deploy/dev | |||
| - feat/model-runtime | |||
| jobs: | |||
| test: | |||
| runs-on: ubuntu-latest | |||
| env: | |||
| OPENAI_API_KEY: sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii | |||
| AZURE_OPENAI_API_BASE: https://difyai-openai.openai.azure.com | |||
| AZURE_OPENAI_API_KEY: xxxxb1707exxxxxxxxxxaaxxxxxf94 | |||
| ANTHROPIC_API_KEY: sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz | |||
| CHATGLM_API_BASE: http://a.abc.com:11451 | |||
| XINFERENCE_SERVER_URL: http://a.abc.com:11451 | |||
| XINFERENCE_GENERATION_MODEL_UID: generate | |||
| XINFERENCE_CHAT_MODEL_UID: chat | |||
| XINFERENCE_EMBEDDINGS_MODEL_UID: embedding | |||
| XINFERENCE_RERANK_MODEL_UID: rerank | |||
| GOOGLE_API_KEY: abcdefghijklmnopqrstuvwxyz | |||
| HUGGINGFACE_API_KEY: hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu | |||
| HUGGINGFACE_TEXT_GEN_ENDPOINT_URL: a | |||
| HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL: b | |||
| HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL: c | |||
| MOCK_SWITCH: true | |||
| steps: | |||
| - name: Checkout code | |||
| uses: actions/checkout@v2 | |||
| - name: Set up Python | |||
| uses: actions/setup-python@v2 | |||
| with: | |||
| python-version: '3.10' | |||
| - name: Cache pip dependencies | |||
| uses: actions/cache@v2 | |||
| with: | |||
| path: ~/.cache/pip | |||
| key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }} | |||
| restore-keys: ${{ runner.os }}-pip- | |||
| - name: Install dependencies | |||
| run: | | |||
| python -m pip install --upgrade pip | |||
| pip install pytest | |||
| pip install -r api/requirements.txt | |||
| - name: Run pytest | |||
| run: pytest api/tests/integration_tests/model_runtime/anthropic api/tests/integration_tests/model_runtime/azure_openai api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py | |||
| @@ -1,38 +0,0 @@ | |||
| name: Run Pytest | |||
| on: | |||
| pull_request: | |||
| branches: | |||
| - main | |||
| push: | |||
| branches: | |||
| - deploy/dev | |||
| jobs: | |||
| test: | |||
| runs-on: ubuntu-latest | |||
| steps: | |||
| - name: Checkout code | |||
| uses: actions/checkout@v2 | |||
| - name: Set up Python | |||
| uses: actions/setup-python@v2 | |||
| with: | |||
| python-version: '3.10' | |||
| - name: Cache pip dependencies | |||
| uses: actions/cache@v2 | |||
| with: | |||
| path: ~/.cache/pip | |||
| key: ${{ runner.os }}-pip-${{ hashFiles('api/requirements.txt') }} | |||
| restore-keys: ${{ runner.os }}-pip- | |||
| - name: Install dependencies | |||
| run: | | |||
| python -m pip install --upgrade pip | |||
| pip install pytest | |||
| pip install -r api/requirements.txt | |||
| - name: Run pytest | |||
| run: pytest api/tests/unit_tests | |||
| @@ -55,6 +55,11 @@ Did you have an issue, like a merge conflict, or don't know how to open a pull r | |||
| Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/j3XRWSPBf7). We are here to help! | |||
| ### Provider Integrations | |||
| If you see a model provider not yet supported by Dify that you'd like to use, follow these [steps](api/core/model_runtime/README.md) to submit a PR. | |||
| ### i18n (Internationalization) Support | |||
| We are looking for contributors to help with translations in other languages. If you are interested in helping, please join the [Discord Community Server](https://discord.gg/AhzKf7dNgk) and let us know. | |||
| @@ -4,6 +4,21 @@ | |||
| // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | |||
| "version": "0.2.0", | |||
| "configurations": [ | |||
| { | |||
| "name": "Python: Celery", | |||
| "type": "python", | |||
| "request": "launch", | |||
| "module": "celery", | |||
| "justMyCode": true, | |||
| "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"], | |||
| "envFile": "${workspaceFolder}/.env", | |||
| "env": { | |||
| "FLASK_APP": "app.py", | |||
| "FLASK_DEBUG": "1", | |||
| "GEVENT_SUPPORT": "True" | |||
| }, | |||
| "console": "integratedTerminal" | |||
| }, | |||
| { | |||
| "name": "Python: Flask", | |||
| "type": "python", | |||
| @@ -34,9 +34,6 @@ RUN apt-get update \ | |||
| COPY --from=base /pkg /usr/local | |||
| COPY . /app/api/ | |||
| RUN python -c "from transformers import GPT2TokenizerFast; GPT2TokenizerFast.from_pretrained('gpt2')" | |||
| ENV TRANSFORMERS_OFFLINE true | |||
| COPY docker/entrypoint.sh /entrypoint.sh | |||
| RUN chmod +x /entrypoint.sh | |||
| @@ -6,9 +6,12 @@ from werkzeug.exceptions import Unauthorized | |||
| if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': | |||
| from gevent import monkey | |||
| monkey.patch_all() | |||
| if os.environ.get("VECTOR_STORE") == 'milvus': | |||
| import grpc.experimental.gevent | |||
| grpc.experimental.gevent.init_gevent() | |||
| # if os.environ.get("VECTOR_STORE") == 'milvus': | |||
| import grpc.experimental.gevent | |||
| grpc.experimental.gevent.init_gevent() | |||
| import langchain | |||
| langchain.verbose = True | |||
| import time | |||
| import logging | |||
| @@ -18,9 +21,8 @@ import threading | |||
| from flask import Flask, request, Response | |||
| from flask_cors import CORS | |||
| from core.model_providers.providers import hosted | |||
| from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | |||
| ext_database, ext_storage, ext_mail, ext_code_based_extension | |||
| ext_database, ext_storage, ext_mail, ext_code_based_extension, ext_hosting_provider | |||
| from extensions.ext_database import db | |||
| from extensions.ext_login import login_manager | |||
| @@ -79,8 +81,6 @@ def create_app(test_config=None) -> Flask: | |||
| register_blueprints(app) | |||
| register_commands(app) | |||
| hosted.init_app(app) | |||
| return app | |||
| @@ -95,6 +95,7 @@ def initialize_extensions(app): | |||
| ext_celery.init_app(app) | |||
| ext_login.init_app(app) | |||
| ext_mail.init_app(app) | |||
| ext_hosting_provider.init_app(app) | |||
| ext_sentry.init_app(app) | |||
| @@ -105,13 +106,18 @@ def load_user_from_request(request_from_flask_login): | |||
| if request.blueprint == 'console': | |||
| # Check if the user_id contains a dot, indicating the old format | |||
| auth_header = request.headers.get('Authorization', '') | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| if not auth_header: | |||
| auth_token = request.args.get('_token') | |||
| if not auth_token: | |||
| raise Unauthorized('Invalid Authorization token.') | |||
| else: | |||
| if ' ' not in auth_header: | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| auth_scheme, auth_token = auth_header.split(None, 1) | |||
| auth_scheme = auth_scheme.lower() | |||
| if auth_scheme != 'bearer': | |||
| raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.') | |||
| decoded = PassportService().verify(auth_token) | |||
| user_id = decoded.get('user_id') | |||
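For readers following the app.py change above: console requests may now authenticate either with a `Bearer` header or, when the header is absent, with a `_token` query parameter. A minimal sketch of that resolution order, assuming it runs inside the api codebase (the helper name `resolve_console_token` is hypothetical; the checks mirror the hunk above):

```python
# Sketch only: mirrors the token-resolution logic added to load_user_from_request.
# The function name is hypothetical; Unauthorized and the Flask request object are as in the file.
from werkzeug.exceptions import Unauthorized


def resolve_console_token(request):
    auth_header = request.headers.get('Authorization', '')
    if not auth_header:
        # New fallback: accept the token via a `_token` query parameter.
        auth_token = request.args.get('_token')
        if not auth_token:
            raise Unauthorized('Invalid Authorization token.')
        return auth_token
    if ' ' not in auth_header:
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    auth_scheme, auth_token = auth_header.split(None, 1)
    if auth_scheme.lower() != 'bearer':
        raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
    return auth_token
```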
| @@ -12,16 +12,12 @@ import qdrant_client | |||
| from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType | |||
| from tqdm import tqdm | |||
| from flask import current_app, Flask | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from werkzeug.exceptions import NotFound | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.index import IndexBuilder | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.hosted import hosted_model_providers | |||
| from core.model_providers.providers.openai_provider import OpenAIProvider | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from libs.password import password_pattern, valid_password, hash_password | |||
| from libs.helper import email as email_validate | |||
| from extensions.ext_database import db | |||
| @@ -327,6 +323,8 @@ def create_qdrant_indexes(): | |||
| except NotFound: | |||
| break | |||
| model_manager = ModelManager() | |||
| page += 1 | |||
| for dataset in datasets: | |||
| if dataset.index_struct_dict: | |||
| @@ -334,19 +332,23 @@ def create_qdrant_indexes(): | |||
| try: | |||
| click.echo('Create dataset qdrant index: {}'.format(dataset.id)) | |||
| try: | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except Exception: | |||
| try: | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id | |||
| embedding_model = model_manager.get_default_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| dataset.embedding_model = embedding_model.name | |||
| dataset.embedding_model_provider = embedding_model.model_provider.provider_name | |||
| dataset.embedding_model = embedding_model.model | |||
| dataset.embedding_model_provider = embedding_model.provider | |||
| except Exception: | |||
| provider = Provider( | |||
| id='provider_id', | |||
| tenant_id=dataset.tenant_id, | |||
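The commands.py hunk above is representative of the PR's central migration: `ModelFactory.get_embedding_model(...)` lookups are replaced by the new `ModelManager`. A minimal sketch of the two call shapes, assuming it runs inside the api codebase (import paths are taken from the diff; the tenant id, provider, and model names are placeholders):

```python
# Sketch only: the ModelManager access pattern this PR migrates to.
# Import paths come from the hunks above; all ids and model names are placeholders.
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()

# Explicit lookup, used when the dataset already records its embedding provider/model.
embedding_model = model_manager.get_model_instance(
    tenant_id="tenant-id-placeholder",
    provider="openai",
    model_type=ModelType.TEXT_EMBEDDING,
    model="text-embedding-ada-002",
)

# Tenant default, used as the fallback when nothing explicit is configured.
default_embedding_model = model_manager.get_default_model_instance(
    tenant_id="tenant-id-placeholder",
    model_type=ModelType.TEXT_EMBEDDING,
)

# The returned instance exposes .provider and .model, which commands.py writes back
# onto dataset.embedding_model_provider and dataset.embedding_model.
```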
| @@ -87,7 +87,7 @@ class Config: | |||
| # ------------------------ | |||
| # General Configurations. | |||
| # ------------------------ | |||
| self.CURRENT_VERSION = "0.3.34" | |||
| self.CURRENT_VERSION = "0.4.0" | |||
| self.COMMIT_SHA = get_env('COMMIT_SHA') | |||
| self.EDITION = "SELF_HOSTED" | |||
| self.DEPLOY_ENV = get_env('DEPLOY_ENV') | |||
| @@ -18,7 +18,7 @@ from .auth import login, oauth, data_source_oauth, activate | |||
| from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source | |||
| # Import workspace controllers | |||
| from .workspace import workspace, members, providers, model_providers, account, tool_providers, models | |||
| from .workspace import workspace, members, model_providers, account, tool_providers, models | |||
| # Import explore controllers | |||
| from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio | |||
| @@ -4,6 +4,10 @@ import logging | |||
| from datetime import datetime | |||
| from flask_login import current_user | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, reqparse, marshal_with, abort, inputs | |||
| from werkzeug.exceptions import Forbidden | |||
| @@ -13,9 +17,7 @@ from controllers.console import api | |||
| from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from events.app_event import app_was_created, app_was_deleted | |||
| from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \ | |||
| app_detail_fields_with_site | |||
| @@ -73,39 +75,41 @@ class AppListApi(Resource): | |||
| raise Forbidden() | |||
| try: | |||
| default_model = ModelFactory.get_text_generation_model( | |||
| tenant_id=current_user.current_tenant_id | |||
| provider_manager = ProviderManager() | |||
| default_model_entity = provider_manager.get_default_model( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=ModelType.LLM | |||
| ) | |||
| except (ProviderTokenNotInitError, LLMBadRequestError): | |||
| default_model = None | |||
| default_model_entity = None | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| default_model = None | |||
| default_model_entity = None | |||
| if args['model_config'] is not None: | |||
| # validate config | |||
| model_config_dict = args['model_config'] | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider( | |||
| current_user.current_tenant_id, | |||
| model_config_dict["model"]["provider"] | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=ModelType.LLM | |||
| ) | |||
| if not model_provider: | |||
| if not default_model: | |||
| raise ProviderNotInitializeError( | |||
| f"No Default System Reasoning Model available. Please configure " | |||
| f"in the Settings -> Model Provider.") | |||
| else: | |||
| model_config_dict["model"]["provider"] = default_model.model_provider.provider_name | |||
| model_config_dict["model"]["name"] = default_model.name | |||
| if not model_instance: | |||
| raise ProviderNotInitializeError( | |||
| f"No Default System Reasoning Model available. Please configure " | |||
| f"in the Settings -> Model Provider.") | |||
| else: | |||
| model_config_dict["model"]["provider"] = model_instance.provider | |||
| model_config_dict["model"]["name"] = model_instance.model | |||
| model_configuration = AppModelConfigService.validate_configuration( | |||
| tenant_id=current_user.current_tenant_id, | |||
| account=current_user, | |||
| config=model_config_dict, | |||
| mode=args['mode'] | |||
| app_mode=args['mode'] | |||
| ) | |||
| app = App( | |||
| @@ -129,21 +133,27 @@ class AppListApi(Resource): | |||
| app_model_config = AppModelConfig(**model_config_template['model_config']) | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider( | |||
| current_user.current_tenant_id, | |||
| app_model_config.model_dict["provider"] | |||
| ) | |||
| if not model_provider: | |||
| if not default_model: | |||
| raise ProviderNotInitializeError( | |||
| f"No Default System Reasoning Model available. Please configure " | |||
| f"in the Settings -> Model Provider.") | |||
| else: | |||
| model_dict = app_model_config.model_dict | |||
| model_dict['provider'] = default_model.model_provider.provider_name | |||
| model_dict['name'] = default_model.name | |||
| app_model_config.model = json.dumps(model_dict) | |||
| model_manager = ModelManager() | |||
| try: | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=ModelType.LLM | |||
| ) | |||
| except ProviderTokenNotInitError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Default System Reasoning Model available. Please configure " | |||
| f"in the Settings -> Model Provider.") | |||
| if not model_instance: | |||
| raise ProviderNotInitializeError( | |||
| f"No Default System Reasoning Model available. Please configure " | |||
| f"in the Settings -> Model Provider.") | |||
| else: | |||
| model_dict = app_model_config.model_dict | |||
| model_dict['provider'] = model_instance.provider | |||
| model_dict['name'] = model_instance.model | |||
| app_model_config.model = json.dumps(model_dict) | |||
| app.name = args['name'] | |||
| app.mode = args['mode'] | |||
| @@ -2,6 +2,8 @@ | |||
| import logging | |||
| from flask import request | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.login import login_required | |||
| from werkzeug.exceptions import InternalServerError | |||
| @@ -14,8 +16,7 @@ from controllers.console.app.error import AppUnavailableError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from flask_restful import Resource | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| @@ -56,8 +57,7 @@ class ChatMessageAudioApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
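This audio controller shows the error-handling change repeated across every controller in the PR: the old tuple of `LLM*` exceptions collapses into the single `InvokeError` base class from the new model runtime. A hedged sketch of the resulting handler shape (the wrapped call is a placeholder; the exception classes and import paths are the ones visible in the hunks):

```python
# Sketch only: the unified error-handling shape used throughout this PR's controllers.
# A single InvokeError catch replaces the old LLMBadRequestError/LLMAPIConnectionError/... tuple.
from controllers.console.app.error import (
    ProviderQuotaExceededError,
    ProviderModelCurrentlyNotSupportError,
    CompletionRequestError,
)
from core.errors.error import QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_runtime.errors.invoke import InvokeError


def run_with_unified_errors(call):
    # `call` is a placeholder for a service invocation such as AudioService.transcript(...).
    try:
        return call()
    except QuotaExceededError:
        raise ProviderQuotaExceededError()
    except ModelCurrentlyNotSupportError:
        raise ProviderModelCurrentlyNotSupportError()
    except InvokeError as e:
        # Any provider-side invocation failure now surfaces through this one base class.
        raise CompletionRequestError(str(e))
```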
| @@ -5,6 +5,10 @@ from typing import Generator, Union | |||
| import flask_login | |||
| from flask import Response, stream_with_context | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.login import login_required | |||
| from werkzeug.exceptions import InternalServerError, NotFound | |||
| @@ -16,9 +20,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail | |||
| ProviderModelCurrentlyNotSupportError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.conversation_message_task import PubHandler | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value | |||
| from flask_restful import Resource, reqparse | |||
| @@ -56,7 +58,7 @@ class CompletionMessageApi(Resource): | |||
| app_model=app_model, | |||
| user=account, | |||
| args=args, | |||
| from_source='console', | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=streaming, | |||
| is_model_config_override=True | |||
| ) | |||
| @@ -75,8 +77,7 @@ class CompletionMessageApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -97,7 +98,7 @@ class CompletionMessageStopApi(Resource): | |||
| account = flask_login.current_user | |||
| PubHandler.stop(account, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -132,7 +133,7 @@ class ChatMessageApi(Resource): | |||
| app_model=app_model, | |||
| user=account, | |||
| args=args, | |||
| from_source='console', | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=streaming, | |||
| is_model_config_override=True | |||
| ) | |||
| @@ -151,8 +152,7 @@ class ChatMessageApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -182,9 +182,8 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(e.description)).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| except Exception: | |||
| @@ -207,7 +206,7 @@ class ChatMessageStopApi(Resource): | |||
| account = flask_login.current_user | |||
| PubHandler.stop(account, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) | |||
| return {'result': 'success'}, 200 | |||
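The stop endpoints swap `PubHandler.stop(user, task_id)` for a queue-level stop flag keyed by the new `InvokeFrom` enum. A minimal sketch of the new flow (the wrapper function is hypothetical; the `set_stop_flag` call is copied from the hunks above):

```python
# Sketch only: the stop-flag pattern this PR adopts for the *StopApi endpoints.
# Console/debugger runs use InvokeFrom.DEBUGGER; the explore and universal-chat
# controllers later in this diff pass InvokeFrom.EXPLORE instead.
from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import InvokeFrom


def stop_debugger_generation(task_id: str, user_id: str) -> dict:
    ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, user_id)
    return {'result': 'success'}
```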
| @@ -1,4 +1,6 @@ | |||
| from flask_login import current_user | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, reqparse | |||
| @@ -8,8 +10,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \ | |||
| LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| class RuleGenerateApi(Resource): | |||
| @@ -36,8 +37,7 @@ class RuleGenerateApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| return rules | |||
| @@ -14,8 +14,9 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni | |||
| AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.login import login_required | |||
| from fields.conversation_fields import message_detail_fields, annotation_fields | |||
| from libs.helper import uuid_value | |||
| @@ -208,7 +209,13 @@ class MessageMoreLikeThisApi(Resource): | |||
| app_model = _get_app(app_id, 'completion') | |||
| try: | |||
| response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming) | |||
| response = CompletionService.generate_more_like_this( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.DEBUGGER, | |||
| streaming=streaming | |||
| ) | |||
| return compact_response(response) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -220,8 +227,7 @@ class MessageMoreLikeThisApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -249,8 +255,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps( | |||
| api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -290,8 +295,7 @@ class MessageSuggestedQuestionApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| @@ -31,7 +31,7 @@ class ModelConfigResource(Resource): | |||
| tenant_id=current_user.current_tenant_id, | |||
| account=current_user, | |||
| config=request.json, | |||
| mode=app.mode | |||
| app_mode=app.mode | |||
| ) | |||
| new_app_model_config = AppModelConfig( | |||
| @@ -4,6 +4,8 @@ from flask import request, current_app | |||
| from flask_login import current_user | |||
| from controllers.console.apikey import api_key_list, api_key_fields | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, reqparse, marshal, marshal_with | |||
| from werkzeug.exceptions import NotFound, Forbidden | |||
| @@ -14,8 +16,7 @@ from controllers.console.datasets.error import DatasetNameDuplicateError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from fields.app_fields import related_app_list | |||
| from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields | |||
| from fields.document_fields import document_status_fields | |||
| @@ -23,7 +24,6 @@ from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Document | |||
| from models.model import UploadFile, ApiToken | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.provider_service import ProviderService | |||
| def _validate_name(name): | |||
| @@ -55,16 +55,20 @@ class DatasetListApi(Resource): | |||
| current_user.current_tenant_id, current_user) | |||
| # check embedding setting | |||
| provider_service = ProviderService() | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, | |||
| ModelType.EMBEDDINGS.value) | |||
| # if len(valid_model_list) == 0: | |||
| # raise ProviderNotInitializeError( | |||
| # f"No Embedding Model available. Please configure a valid provider " | |||
| # f"in the Settings -> Model Provider.") | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| embedding_models = configurations.get_models( | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| only_active=True | |||
| ) | |||
| model_names = [] | |||
| for valid_model in valid_model_list: | |||
| model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") | |||
| for embedding_model in embedding_models: | |||
| model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") | |||
| data = marshal(datasets, dataset_detail_fields) | |||
| for item in data: | |||
| if item['indexing_technique'] == 'high_quality': | |||
| @@ -75,6 +79,7 @@ class DatasetListApi(Resource): | |||
| item['embedding_available'] = False | |||
| else: | |||
| item['embedding_available'] = True | |||
| response = { | |||
| 'data': data, | |||
| 'has_more': len(datasets) == limit, | |||
| @@ -130,13 +135,20 @@ class DatasetApi(Resource): | |||
| raise Forbidden(str(e)) | |||
| data = marshal(dataset, dataset_detail_fields) | |||
| # check embedding setting | |||
| provider_service = ProviderService() | |||
| # get valid model list | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, | |||
| ModelType.EMBEDDINGS.value) | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| embedding_models = configurations.get_models( | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| only_active=True | |||
| ) | |||
| model_names = [] | |||
| for valid_model in valid_model_list: | |||
| model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") | |||
| for embedding_model in embedding_models: | |||
| model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") | |||
| if data['indexing_technique'] == 'high_quality': | |||
| item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" | |||
| if item_model in model_names: | |||
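Both dataset endpoints above now derive the set of usable embedding models from `ProviderManager` configurations instead of `ProviderService.get_valid_model_list(...)`. A minimal sketch of that lookup, assuming it runs inside the api codebase (the helper name is hypothetical; the calls and attribute names follow the hunks above):

```python
# Sketch only: how the dataset APIs build the "<model>:<provider>" list used for
# the embedding_available check. The helper name is hypothetical.
from core.provider_manager import ProviderManager
from core.model_runtime.entities.model_entities import ModelType


def active_embedding_model_names(tenant_id: str) -> list:
    configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
    embedding_models = configurations.get_models(
        model_type=ModelType.TEXT_EMBEDDING,
        only_active=True,
    )
    # Each dataset's "<embedding_model>:<embedding_model_provider>" is checked
    # against this list to decide whether embedding is still available.
    return [f"{m.model}:{m.provider.provider}" for m in embedding_models]
```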
| @@ -2,8 +2,12 @@ | |||
| from datetime import datetime | |||
| from typing import List | |||
| from flask import request, current_app | |||
| from flask import request | |||
| from flask_login import current_user | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, fields, marshal, marshal_with, reqparse | |||
| from sqlalchemy import desc, asc | |||
| @@ -18,9 +22,8 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_redis import redis_client | |||
| from fields.document_fields import document_with_segments_fields, document_fields, \ | |||
| dataset_and_document_fields, document_status_fields | |||
| @@ -272,10 +275,12 @@ class DatasetInitApi(Resource): | |||
| args = parser.parse_args() | |||
| if args['indexing_technique'] == 'high_quality': | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| tenant_id=current_user.current_tenant_id | |||
| model_manager = ModelManager() | |||
| model_manager.get_default_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING | |||
| ) | |||
| except LLMBadRequestError: | |||
| except InvokeAuthorizationError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| @@ -12,8 +12,9 @@ from controllers.console.app.error import ProviderNotInitializeError | |||
| from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check | |||
| from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from libs.login import login_required | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| @@ -133,10 +134,12 @@ class DatasetDocumentSegmentApi(Resource): | |||
| if dataset.indexing_technique == 'high_quality': | |||
| # check embedding model setting | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| @@ -219,10 +222,12 @@ class DatasetDocumentSegmentAddApi(Resource): | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| @@ -269,10 +274,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| if dataset.indexing_technique == 'high_quality': | |||
| # check embedding model setting | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| @@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu | |||
| from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| LLMBadRequestError | |||
| from fields.hit_testing_fields import hit_testing_record_fields | |||
| from services.dataset_service import DatasetService | |||
| @@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia | |||
| NoAudioUploadedError, AudioTooLargeError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError | |||
| @@ -53,8 +53,7 @@ class ChatAudioApi(InstalledAppResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -15,9 +15,10 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.explore.error import NotCompletionAppError, NotChatAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from extensions.ext_database import db | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -50,7 +51,7 @@ class CompletionApi(InstalledAppResource): | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| from_source='console', | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=streaming | |||
| ) | |||
| @@ -68,8 +69,7 @@ class CompletionApi(InstalledAppResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -84,7 +84,7 @@ class CompletionStopApi(InstalledAppResource): | |||
| if app_model.mode != 'completion': | |||
| raise NotCompletionAppError() | |||
| PubHandler.stop(current_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -115,7 +115,7 @@ class ChatApi(InstalledAppResource): | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| from_source='console', | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=streaming | |||
| ) | |||
| @@ -133,8 +133,7 @@ class ChatApi(InstalledAppResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -149,7 +148,7 @@ class ChatStopApi(InstalledAppResource): | |||
| if app_model.mode != 'chat': | |||
| raise NotChatAppError() | |||
| PubHandler.stop(current_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -175,8 +174,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -5,7 +5,7 @@ from typing import Generator, Union | |||
| from flask import stream_with_context, Response | |||
| from flask_login import current_user | |||
| from flask_restful import reqparse, fields, marshal_with | |||
| from flask_restful import reqparse, marshal_with | |||
| from flask_restful.inputs import int_range | |||
| from werkzeug.exceptions import NotFound, InternalServerError | |||
| @@ -13,12 +13,14 @@ import services | |||
| from controllers.console import api | |||
| from controllers.console.app.error import AppMoreLikeThisDisabledError, ProviderNotInitializeError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError | |||
| from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ | |||
| NotChatAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from fields.message_fields import message_infinite_scroll_pagination_fields | |||
| from libs.helper import uuid_value, TimestampField | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| from services.errors.app import MoreLikeThisDisabledError | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| @@ -83,7 +85,13 @@ class MessageMoreLikeThisApi(InstalledAppResource): | |||
| streaming = args['response_mode'] == 'streaming' | |||
| try: | |||
| response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming) | |||
| response = CompletionService.generate_more_like_this( | |||
| app_model=app_model, | |||
| user=current_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=streaming | |||
| ) | |||
| return compact_response(response) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -95,8 +103,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -123,8 +130,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -162,8 +168,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| @@ -11,8 +11,8 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia | |||
| NoAudioUploadedError, AudioTooLargeError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError | |||
| @@ -53,8 +53,7 @@ class UniversalChatAudioApi(UniversalChatResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -12,9 +12,10 @@ from controllers.console import api | |||
| from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -68,7 +69,7 @@ class UniversalChatApi(UniversalChatResource): | |||
| app_model=app_model, | |||
| user=current_user, | |||
| args=args, | |||
| from_source='console', | |||
| invoke_from=InvokeFrom.EXPLORE, | |||
| streaming=True, | |||
| is_model_config_override=True, | |||
| ) | |||
| @@ -87,8 +88,7 @@ class UniversalChatApi(UniversalChatResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -99,7 +99,7 @@ class UniversalChatApi(UniversalChatResource): | |||
| class UniversalChatStopApi(UniversalChatResource): | |||
| def post(self, universal_app, task_id): | |||
| PubHandler.stop(current_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -125,8 +125,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -12,8 +12,8 @@ from controllers.console.app.error import ProviderNotInitializeError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.helper import uuid_value, TimestampField | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError | |||
| @@ -132,8 +132,7 @@ class UniversalChatMessageSuggestedQuestionApi(UniversalChatResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| @@ -1,16 +1,19 @@ | |||
| import io | |||
| from flask import send_file | |||
| from flask_login import current_user | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ProviderNotInitializeError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from services.provider_service import ProviderService | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.login import login_required | |||
| from services.billing_service import BillingService | |||
| from services.model_provider_service import ModelProviderService | |||
| class ModelProviderListApi(Resource): | |||
| @@ -22,13 +25,36 @@ class ModelProviderListApi(Resource): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_type', type=str, required=False, nullable=True, location='args') | |||
| parser.add_argument('model_type', type=str, required=False, nullable=True, | |||
| choices=[mt.value for mt in ModelType], location='args') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| provider_list = provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get('model_type')) | |||
| model_provider_service = ModelProviderService() | |||
| provider_list = model_provider_service.get_provider_list( | |||
| tenant_id=tenant_id, | |||
| model_type=args.get('model_type') | |||
| ) | |||
| return jsonable_encoder({"data": provider_list}) | |||
| class ModelProviderCredentialApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_provider_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider | |||
| ) | |||
| return provider_list | |||
| return { | |||
| "credentials": credentials | |||
| } | |||
| class ModelProviderValidateApi(Resource): | |||
| @@ -36,21 +62,24 @@ class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| def post(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| result = True | |||
| error = None | |||
| try: | |||
| provider_service.custom_provider_config_validate( | |||
| provider_name=provider_name, | |||
| config=args['config'] | |||
| model_provider_service.provider_credentials_validate( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| credentials=args['credentials'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| @@ -64,26 +93,26 @@ class ModelProviderValidateApi(Resource): | |||
| return response | |||
| class ModelProviderUpdateApi(Resource): | |||
| class ModelProviderApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| def post(self, provider: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| provider_service.save_custom_provider_config( | |||
| model_provider_service.save_provider_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| config=args['config'] | |||
| provider=provider, | |||
| credentials=args['credentials'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| @@ -93,109 +122,36 @@ class ModelProviderUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider_name: str): | |||
| def delete(self, provider: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| provider_service = ProviderService() | |||
| provider_service.delete_custom_provider( | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_provider_credentials( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name | |||
| provider=provider | |||
| ) | |||
| return {'result': 'success'}, 204 | |||
| class ModelProviderModelValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| result = True | |||
| error = None | |||
| try: | |||
| provider_service.custom_provider_model_config_validate( | |||
| provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type=args['model_type'], | |||
| config=args['config'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| if not result: | |||
| response['error'] = error | |||
| return response | |||
| class ModelProviderModelUpdateApi(Resource): | |||
| class ModelProviderIconApi(Resource): | |||
| """ | |||
| Get model provider icon | |||
| """ | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| try: | |||
| provider_service.add_or_save_custom_provider_model_config( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type=args['model_type'], | |||
| config=args['config'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| provider_service.delete_custom_provider_model( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type=args['model_type'] | |||
| def get(self, provider: str, icon_type: str, lang: str): | |||
| model_provider_service = ModelProviderService() | |||
| icon, mimetype = model_provider_service.get_model_provider_icon( | |||
| provider=provider, | |||
| icon_type=icon_type, | |||
| lang=lang | |||
| ) | |||
| return {'result': 'success'}, 204 | |||
| return send_file(io.BytesIO(icon), mimetype=mimetype) | |||
| class PreferredProviderTypeUpdateApi(Resource): | |||
| @@ -203,71 +159,36 @@ class PreferredProviderTypeUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| def post(self, provider: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, | |||
| choices=['system', 'custom'], location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| provider_service.switch_preferred_provider( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.switch_preferred_provider( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| preferred_provider_type=args['preferred_provider_type'] | |||
| ) | |||
| return {'result': 'success'} | |||
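A hedged sketch of the switch this endpoint now delegates to; the tenant id and provider name are illustrative:

    from services.model_provider_service import ModelProviderService

    # Switch a provider back to the hosted (system) quota; 'custom' is the other
    # accepted value. Identifiers below are placeholders.
    ModelProviderService().switch_preferred_provider(
        tenant_id="tenant-id",
        provider="anthropic",
        preferred_provider_type="system",
    )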
| class ModelProviderModelParameterRuleApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| try: | |||
| parameter_rules = provider_service.get_model_parameter_rules( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type='text-generation' | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"Current Text Generation Model is invalid. Please switch to the available model.") | |||
| rules = { | |||
| k: { | |||
| 'enabled': v.enabled, | |||
| 'min': v.min, | |||
| 'max': v.max, | |||
| 'default': v.default, | |||
| 'precision': v.precision | |||
| } | |||
| for k, v in vars(parameter_rules).items() | |||
| } | |||
| return rules | |||
| class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| if provider_name != 'anthropic': | |||
| raise ValueError(f'provider name {provider_name} is invalid') | |||
| def get(self, provider: str): | |||
| if provider != 'anthropic': | |||
| raise ValueError(f'provider name {provider} is invalid') | |||
| data = BillingService.get_model_provider_payment_link(provider_name=provider_name, | |||
| data = BillingService.get_model_provider_payment_link(provider_name=provider, | |||
| tenant_id=current_user.current_tenant_id, | |||
| account_id=current_user.id) | |||
| return data | |||
| @@ -277,11 +198,11 @@ class ModelProviderFreeQuotaSubmitApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| provider_service = ProviderService() | |||
| result = provider_service.free_quota_submit( | |||
| def post(self, provider: str): | |||
| model_provider_service = ModelProviderService() | |||
| result = model_provider_service.free_quota_submit( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name | |||
| provider=provider | |||
| ) | |||
| return result | |||
| @@ -291,15 +212,15 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| def get(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=str, required=False, nullable=True, location='args') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| result = provider_service.free_quota_qualification_verify( | |||
| model_provider_service = ModelProviderService() | |||
| result = model_provider_service.free_quota_qualification_verify( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| provider=provider, | |||
| token=args['token'] | |||
| ) | |||
| @@ -307,19 +228,18 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): | |||
| api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') | |||
| api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate') | |||
| api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>') | |||
| api.add_resource(ModelProviderModelValidateApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/models/validate') | |||
| api.add_resource(ModelProviderModelUpdateApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/models') | |||
| api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials') | |||
| api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate') | |||
| api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/<string:provider>') | |||
| api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/' | |||
| '<string:icon_type>/<string:lang>') | |||
| api.add_resource(PreferredProviderTypeUpdateApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type') | |||
| api.add_resource(ModelProviderModelParameterRuleApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules') | |||
| '/workspaces/current/model-providers/<string:provider>/preferred-provider-type') | |||
| api.add_resource(ModelProviderPaymentCheckoutUrlApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/checkout-url') | |||
| '/workspaces/current/model-providers/<string:provider>/checkout-url') | |||
| api.add_resource(ModelProviderFreeQuotaSubmitApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit') | |||
| '/workspaces/current/model-providers/<string:provider>/free-quota-submit') | |||
| api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify') | |||
| '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify') | |||
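Putting the renamed routes together, a hedged sketch of validating and then saving provider credentials over the console API; the base URL, the authenticated session, and the credential key name are assumptions for illustration:

    import requests

    BASE = "http://localhost:5001/console/api"  # assumed local console API address
    s = requests.Session()                      # assumed to carry a logged-in console session

    payload = {"credentials": {"openai_api_key": "sk-..."}}  # credential key is illustrative
    # Validate first: returns {"result": "success"} or {"result": "error", "error": "..."}.
    print(s.post(f"{BASE}/workspaces/current/model-providers/openai/credentials/validate", json=payload).json())
    # Then persist them (the controller requires an admin/owner role).
    print(s.post(f"{BASE}/workspaces/current/model-providers/openai", json=payload).json())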
| @@ -1,16 +1,17 @@ | |||
| import logging | |||
| from flask_login import current_user | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, reqparse | |||
| from flask_restful import reqparse, Resource | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from models.provider import ProviderType | |||
| from services.provider_service import ProviderService | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.utils.encoders import jsonable_encoder | |||
| from libs.login import login_required | |||
| from services.model_provider_service import ModelProviderService | |||
| class DefaultModelApi(Resource): | |||
| @@ -21,52 +22,20 @@ class DefaultModelApi(Resource): | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args') | |||
| choices=[mt.value for mt in ModelType], location='args') | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| provider_service = ProviderService() | |||
| default_model = provider_service.get_default_model_of_model_type( | |||
| model_provider_service = ModelProviderService() | |||
| default_model_entity = model_provider_service.get_default_model_of_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=args['model_type'] | |||
| ) | |||
| if not default_model: | |||
| return None | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider( | |||
| tenant_id, | |||
| default_model.provider_name | |||
| ) | |||
| if not model_provider: | |||
| return { | |||
| 'model_name': default_model.model_name, | |||
| 'model_type': default_model.model_type, | |||
| 'model_provider': { | |||
| 'provider_name': default_model.provider_name | |||
| } | |||
| } | |||
| provider = model_provider.provider | |||
| rst = { | |||
| 'model_name': default_model.model_name, | |||
| 'model_type': default_model.model_type, | |||
| 'model_provider': { | |||
| 'provider_name': provider.provider_name, | |||
| 'provider_type': provider.provider_type | |||
| } | |||
| } | |||
| model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name) | |||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||
| rst['model_provider']['quota_type'] = provider.quota_type | |||
| rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit'] | |||
| rst['model_provider']['quota_limit'] = provider.quota_limit | |||
| rst['model_provider']['quota_used'] = provider.quota_used | |||
| return rst | |||
| return jsonable_encoder({ | |||
| "data": default_model_entity | |||
| }) | |||
| @setup_required | |||
| @login_required | |||
| @@ -76,15 +45,26 @@ class DefaultModelApi(Resource): | |||
| parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| model_settings = args['model_settings'] | |||
| for model_setting in model_settings: | |||
| if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: | |||
| raise ValueError('invalid model type') | |||
| if 'provider' not in model_setting: | |||
| continue | |||
| if 'model' not in model_setting: | |||
| raise ValueError('invalid model') | |||
| try: | |||
| provider_service.update_default_model_of_model_type( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_service.update_default_model_of_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=model_setting['model_type'], | |||
| provider_name=model_setting['provider_name'], | |||
| model_name=model_setting['model_name'] | |||
| provider=model_setting['provider'], | |||
| model=model_setting['model'] | |||
| ) | |||
| except Exception: | |||
| logging.warning(f"{model_setting['model_type']} save error") | |||
| @@ -92,22 +72,198 @@ class DefaultModelApi(Resource): | |||
| return {'result': 'success'} | |||
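A hedged sketch of the reshaped default-model update that the loop above performs for each entry; all values are placeholders, and "llm" is assumed to be one of the ModelType values:

    from services.model_provider_service import ModelProviderService

    # Each model_settings entry now carries 'provider' and 'model' instead of
    # 'provider_name'/'model_name'.
    ModelProviderService().update_default_model_of_model_type(
        tenant_id="tenant-id",
        model_type="llm",
        provider="openai",
        model="gpt-3.5-turbo",
    )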
| class ValidModelApi(Resource): | |||
| class ModelProviderModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider): | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| models = model_provider_service.get_models_by_provider( | |||
| tenant_id=tenant_id, | |||
| provider=provider | |||
| ) | |||
| return jsonable_encoder({ | |||
| "data": models | |||
| }) | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| try: | |||
| model_provider_service.save_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| credentials=args['credentials'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| model_provider_service.remove_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'] | |||
| ) | |||
| return {'result': 'success'}, 204 | |||
| class ModelProviderModelCredentialApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='args') | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| credentials = model_provider_service.get_model_credentials( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model_type=args['model_type'], | |||
| model=args['model'] | |||
| ) | |||
| return { | |||
| "credentials": credentials | |||
| } | |||
| class ModelProviderModelValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider: str): | |||
| tenant_id = current_user.current_tenant_id | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=[mt.value for mt in ModelType], location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| model_provider_service = ModelProviderService() | |||
| result = True | |||
| error = None | |||
| try: | |||
| model_provider_service.model_credentials_validate( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'], | |||
| model_type=args['model_type'], | |||
| credentials=args['credentials'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| if not result: | |||
| response['error'] = error | |||
| return response | |||
| class ModelProviderModelParameterRuleApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model', type=str, required=True, nullable=False, location='args') | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| model_provider_service = ModelProviderService() | |||
| parameter_rules = model_provider_service.get_model_parameter_rules( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model=args['model'] | |||
| ) | |||
| return jsonable_encoder({ | |||
| "data": parameter_rules | |||
| }) | |||
| class ModelProviderAvailableModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, model_type): | |||
| ModelType.value_of(model_type) | |||
| tenant_id = current_user.current_tenant_id | |||
| provider_service = ProviderService() | |||
| valid_models = provider_service.get_valid_model_list( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_service = ModelProviderService() | |||
| models = model_provider_service.get_models_by_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type | |||
| ) | |||
| return valid_models | |||
| return jsonable_encoder({ | |||
| "data": models | |||
| }) | |||
| api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models') | |||
| api.add_resource(ModelProviderModelCredentialApi, | |||
| '/workspaces/current/model-providers/<string:provider>/models/credentials') | |||
| api.add_resource(ModelProviderModelValidateApi, | |||
| '/workspaces/current/model-providers/<string:provider>/models/credentials/validate') | |||
| api.add_resource(ModelProviderModelParameterRuleApi, | |||
| '/workspaces/current/model-providers/<string:provider>/models/parameter-rules') | |||
| api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>') | |||
| api.add_resource(DefaultModelApi, '/workspaces/current/default-model') | |||
| api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>') | |||
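As a usage sketch for the model-level endpoints registered above, saving credentials for a single customizable model could look like this; provider, model and the credential key are illustrative only:

    from services.model_provider_service import ModelProviderService

    # Hedged sketch: register a model under a provider with its own credentials.
    ModelProviderService().save_model_credentials(
        tenant_id="tenant-id",
        provider="huggingface_hub",
        model="gpt2",
        model_type="llm",
        credentials={"huggingfacehub_api_token": "hf_..."},  # key name assumed
    )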
| @@ -1,131 +0,0 @@ | |||
| # -*- coding:utf-8 -*- | |||
| from flask_login import current_user | |||
| from libs.login import login_required | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from models.provider import ProviderType | |||
| from services.provider_service import ProviderService | |||
| class ProviderListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| """ | |||
| If the type is AZURE_OPENAI, decode and return the four fields azure_api_type, azure_api_version, | |||
| azure_api_base and azure_api_key as an object, where azure_api_key shows the first 6 characters in | |||
| plaintext, the rest is replaced by * and the last two characters are shown in plaintext. | |||
| For any other type, decode and return the Token field directly; the field shows the first 6 characters in | |||
| plaintext, the rest is replaced by * and the last two characters are shown in plaintext. | |||
| """ | |||
| provider_service = ProviderService() | |||
| provider_info_list = provider_service.get_provider_list(tenant_id) | |||
| provider_list = [ | |||
| { | |||
| 'provider_name': p['provider_name'], | |||
| 'provider_type': p['provider_type'], | |||
| 'is_valid': p['is_valid'], | |||
| 'last_used': p['last_used'], | |||
| 'is_enabled': p['is_valid'], | |||
| **({ | |||
| 'quota_type': p['quota_type'], | |||
| 'quota_limit': p['quota_limit'], | |||
| 'quota_used': p['quota_used'] | |||
| } if p['provider_type'] == ProviderType.SYSTEM.value else {}), | |||
| 'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key']) | |||
| if p['config'] else None | |||
| } | |||
| for name, provider_info in provider_info_list.items() | |||
| for p in provider_info['providers'] | |||
| ] | |||
| return provider_list | |||
| class ProviderTokenApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| if provider == 'openai': | |||
| args['token'] = { | |||
| 'openai_api_key': args['token'] | |||
| } | |||
| provider_service = ProviderService() | |||
| try: | |||
| provider_service.save_custom_provider_config( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider, | |||
| config=args['token'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 201 | |||
| class ProviderTokenValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| if provider == 'openai': | |||
| args['token'] = { | |||
| 'openai_api_key': args['token'] | |||
| } | |||
| result = True | |||
| error = None | |||
| try: | |||
| provider_service.custom_provider_config_validate( | |||
| provider_name=provider, | |||
| config=args['token'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| if not result: | |||
| response['error'] = error | |||
| return response | |||
| api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token', | |||
| endpoint='workspaces_current_providers_token') # PUT for updating provider token | |||
| api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate', | |||
| endpoint='workspaces_current_providers_token_validate') # POST for validating provider token | |||
| api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list | |||
| @@ -34,7 +34,6 @@ tenant_fields = { | |||
| 'status': fields.String, | |||
| 'created_at': TimestampField, | |||
| 'role': fields.String, | |||
| 'providers': fields.List(fields.Nested(provider_fields)), | |||
| 'in_trial': fields.Boolean, | |||
| 'trial_end_reason': fields.String, | |||
| 'custom_config': fields.Raw(attribute='custom_config'), | |||
| @@ -9,8 +9,8 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn | |||
| ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \ | |||
| ProviderNotSupportSpeechToTextError | |||
| from controllers.service_api.wraps import AppApiResource | |||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from models.model import App, AppModelConfig | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| @@ -49,8 +49,7 @@ class AudioApi(AppApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -13,9 +13,10 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn | |||
| ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \ | |||
| ProviderModelCurrentlyNotSupportError | |||
| from controllers.service_api.wraps import AppApiResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -47,7 +48,7 @@ class CompletionApi(AppApiResource): | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| from_source='api', | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| streaming=streaming, | |||
| ) | |||
| @@ -65,8 +66,7 @@ class CompletionApi(AppApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -80,7 +80,7 @@ class CompletionStopApi(AppApiResource): | |||
| if app_model.mode != 'completion': | |||
| raise AppUnavailableError() | |||
| PubHandler.stop(end_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -112,7 +112,7 @@ class ChatApi(AppApiResource): | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| from_source='api', | |||
| invoke_from=InvokeFrom.SERVICE_API, | |||
| streaming=streaming | |||
| ) | |||
| @@ -130,8 +130,7 @@ class ChatApi(AppApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -145,7 +144,7 @@ class ChatStopApi(AppApiResource): | |||
| if app_model.mode != 'chat': | |||
| raise NotChatAppError() | |||
| PubHandler.stop(end_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
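The stop endpoints now set a queue-level flag instead of publishing through PubHandler; a minimal sketch of the new call, with placeholder identifiers and the same positional arguments as the diff (task id, invoke source, user id):

    from core.application_queue_manager import ApplicationQueueManager
    from core.entities.application_entities import InvokeFrom

    ApplicationQueueManager.set_stop_flag("task-id", InvokeFrom.SERVICE_API, "end-user-id")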
| @@ -171,8 +170,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -4,11 +4,11 @@ import services.dataset_service | |||
| from controllers.service_api import api | |||
| from controllers.service_api.dataset.error import DatasetNameDuplicateError | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.provider_manager import ProviderManager | |||
| from libs.login import current_user | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from fields.dataset_fields import dataset_detail_fields | |||
| from services.dataset_service import DatasetService | |||
| from services.provider_service import ProviderService | |||
| def _validate_name(name): | |||
| @@ -27,12 +27,20 @@ class DatasetApi(DatasetApiResource): | |||
| datasets, total = DatasetService.get_datasets(page, limit, provider, | |||
| tenant_id, current_user) | |||
| # check embedding setting | |||
| provider_service = ProviderService() | |||
| valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id, | |||
| ModelType.EMBEDDINGS.value) | |||
| provider_manager = ProviderManager() | |||
| configurations = provider_manager.get_configurations( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| embedding_models = configurations.get_models( | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| only_active=True | |||
| ) | |||
| model_names = [] | |||
| for valid_model in valid_model_list: | |||
| model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}") | |||
| for embedding_model in embedding_models: | |||
| model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") | |||
| data = marshal(datasets, dataset_detail_fields) | |||
| for item in data: | |||
| if item['indexing_technique'] == 'high_quality': | |||
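For context, the embedding-model check above boils down to the following lookup (the tenant id is a placeholder):

    from core.model_runtime.entities.model_entities import ModelType
    from core.provider_manager import ProviderManager

    configurations = ProviderManager().get_configurations(tenant_id="tenant-id")
    embedding_models = configurations.get_models(
        model_type=ModelType.TEXT_EMBEDDING,
        only_active=True,
    )
    # "<model>:<provider>" strings, matching the dataset embedding settings format.
    model_names = [f"{m.model}:{m.provider.provider}" for m in embedding_models]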
| @@ -13,7 +13,7 @@ from controllers.service_api.dataset.error import ArchivedDocumentImmutableError | |||
| NoFileUploadedError, TooManyFilesError | |||
| from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check | |||
| from libs.login import current_user | |||
| from core.model_providers.error import ProviderTokenNotInitError | |||
| from core.errors.error import ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from fields.document_fields import document_fields, document_status_fields | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| @@ -4,8 +4,9 @@ from werkzeug.exceptions import NotFound | |||
| from controllers.service_api import api | |||
| from controllers.service_api.app.error import ProviderNotInitializeError | |||
| from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check | |||
| from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.errors.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from fields.segment_fields import segment_fields | |||
| from models.dataset import Dataset, DocumentSegment | |||
| @@ -35,10 +36,12 @@ class SegmentApi(DatasetApiResource): | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| @@ -77,10 +80,12 @@ class SegmentApi(DatasetApiResource): | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == 'high_quality': | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| @@ -167,10 +172,12 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| if dataset.indexing_technique == 'high_quality': | |||
| # check embedding model setting | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
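The same replacement appears three times above; in isolation the new availability check is just the following (all values are placeholders), raising when the tenant has no usable credentials for the embedding model:

    from core.model_manager import ModelManager
    from core.model_runtime.entities.model_entities import ModelType

    ModelManager().get_model_instance(
        tenant_id="tenant-id",
        provider="openai",
        model_type=ModelType.TEXT_EMBEDDING,
        model="text-embedding-ada-002",
    )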
| @@ -10,8 +10,8 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError | |||
| @@ -51,8 +51,7 @@ class AudioApi(WebApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -13,9 +13,10 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro | |||
| ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -44,7 +45,7 @@ class CompletionApi(WebApiResource): | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| from_source='api', | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| streaming=streaming | |||
| ) | |||
| @@ -62,8 +63,7 @@ class CompletionApi(WebApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -77,7 +77,7 @@ class CompletionStopApi(WebApiResource): | |||
| if app_model.mode != 'completion': | |||
| raise NotCompletionAppError() | |||
| PubHandler.stop(end_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -105,7 +105,7 @@ class ChatApi(WebApiResource): | |||
| app_model=app_model, | |||
| user=end_user, | |||
| args=args, | |||
| from_source='api', | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| streaming=streaming | |||
| ) | |||
| @@ -123,8 +123,7 @@ class ChatApi(WebApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -138,7 +137,7 @@ class ChatStopApi(WebApiResource): | |||
| if app_model.mode != 'chat': | |||
| raise NotChatAppError() | |||
| PubHandler.stop(end_user, task_id) | |||
| ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) | |||
| return {'result': 'success'}, 200 | |||
| @@ -164,8 +163,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -14,8 +14,9 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi | |||
| AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.errors.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from libs.helper import uuid_value, TimestampField | |||
| from services.completion_service import CompletionService | |||
| from services.errors.app import MoreLikeThisDisabledError | |||
| @@ -117,7 +118,14 @@ class MessageMoreLikeThisApi(WebApiResource): | |||
| streaming = args['response_mode'] == 'streaming' | |||
| try: | |||
| response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app') | |||
| response = CompletionService.generate_more_like_this( | |||
| app_model=app_model, | |||
| user=end_user, | |||
| message_id=message_id, | |||
| invoke_from=InvokeFrom.WEB_APP, | |||
| streaming=streaming | |||
| ) | |||
| return compact_response(response) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -129,8 +137,7 @@ class MessageMoreLikeThisApi(WebApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except ValueError as e: | |||
| raise e | |||
| @@ -157,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" | |||
| except ModelCurrentlyNotSupportError: | |||
| yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" | |||
| except ValueError as e: | |||
| yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" | |||
| @@ -195,8 +201,7 @@ class MessageSuggestedQuestionApi(WebApiResource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| except InvokeError as e: | |||
| raise CompletionRequestError(str(e)) | |||
| except Exception: | |||
| logging.exception("internal server error.") | |||
| @@ -0,0 +1,101 @@ | |||
| import logging | |||
| from typing import Optional, List | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
| from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| logger = logging.getLogger(__name__) | |||
| class AgentLLMCallback(Callback): | |||
| def __init__(self, agent_callback: AgentLoopGatherCallbackHandler) -> None: | |||
| self.agent_callback = agent_callback | |||
| def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| Before invoke callback | |||
| :param llm_instance: LLM instance | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param tools: tools for tool calling | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| """ | |||
| self.agent_callback.on_llm_before_invoke( | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None): | |||
| """ | |||
| On new chunk callback | |||
| :param llm_instance: LLM instance | |||
| :param chunk: chunk | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param tools: tools for tool calling | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| """ | |||
| pass | |||
| def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| After invoke callback | |||
| :param llm_instance: LLM instance | |||
| :param result: result | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param tools: tools for tool calling | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| """ | |||
| self.agent_callback.on_llm_after_invoke( | |||
| result=result | |||
| ) | |||
| def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| Invoke error callback | |||
| :param llm_instance: LLM instance | |||
| :param ex: exception | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param tools: tools for tool calling | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| """ | |||
| self.agent_callback.on_llm_error( | |||
| error=ex | |||
| ) | |||
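A small sketch of what this callback forwards; the stub below stands in for a real AgentLoopGatherCallbackHandler, whose construction is outside this diff, and only exposes the three hooks the callback calls:

    from types import SimpleNamespace

    from core.agent.agent.agent_llm_callback import AgentLLMCallback

    # Stub gatherer (illustrative only) exposing the three forwarded hooks.
    stub_gatherer = SimpleNamespace(
        on_llm_before_invoke=lambda prompt_messages: print("before invoke"),
        on_llm_after_invoke=lambda result: print("after invoke"),
        on_llm_error=lambda error: print(f"agent loop saw error: {error}"),
    )
    callback = AgentLLMCallback(agent_callback=stub_gatherer)
    # Invoke errors are forwarded straight to on_llm_error.
    callback.on_invoke_error(llm_instance=None, ex=RuntimeError("boom"), model="gpt-3.5-turbo",
                             credentials={}, prompt_messages=[], model_parameters={})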
| @@ -1,28 +1,49 @@ | |||
| from typing import List | |||
| from typing import List, cast | |||
| from langchain.schema import BaseMessage | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| class CalcTokenMixin: | |||
| def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||
| return model_instance.get_num_tokens(to_prompt_messages(messages)) | |||
| def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||
| def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int: | |||
| """ | |||
| Get the rest tokens available for the model after excluding message tokens and completion max tokens | |||
| :param llm: | |||
| :param model_config: | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| llm_max_tokens = model_instance.model_rules.max_tokens.max | |||
| completion_max_tokens = model_instance.model_kwargs.max_tokens | |||
| used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs) | |||
| rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) | |||
| max_tokens = 0 | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| max_tokens = (model_config.parameters.get(parameter_rule.name) | |||
| or model_config.parameters.get(parameter_rule.use_template)) or 0 | |||
| if model_context_tokens is None: | |||
| return 0 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| prompt_tokens = model_type_instance.get_num_tokens( | |||
| model_config.model, | |||
| model_config.credentials, | |||
| messages | |||
| ) | |||
| rest_tokens = model_context_tokens - max_tokens - prompt_tokens | |||
| return rest_tokens | |||
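A worked example of the budget computed above, with illustrative numbers:

    # Illustrative numbers only:
    model_context_tokens = 4096   # CONTEXT_SIZE model property
    max_tokens = 512              # completion budget from the 'max_tokens' parameter rule
    prompt_tokens = 1300          # tokens counted for the prompt messages
    rest_tokens = model_context_tokens - max_tokens - prompt_tokens
    print(rest_tokens)            # 2284 tokens left over, e.g. for retrieved context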
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| from typing import Tuple, List, Any, Union, Sequence, Optional, cast | |||
| from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent | |||
| @@ -6,13 +5,14 @@ from langchain.agents.openai_functions_agent.base import _format_intermediate_st | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage | |||
| from langchain.tools import BaseTool | |||
| from pydantic import root_validator | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_manager import ModelInstance | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| @@ -20,7 +20,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| """ | |||
| A Multi Dataset Retrieve Agent driven by Router. | |||
| """ | |||
| model_instance: BaseLLM | |||
| model_config: ModelConfigEntity | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| @@ -81,8 +81,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| agent_decision.return_values['output'] = '' | |||
| return agent_decision | |||
| except Exception as e: | |||
| new_exception = self.model_instance.handle_exceptions(e) | |||
| raise new_exception | |||
| raise e | |||
| def real_plan( | |||
| self, | |||
| @@ -106,16 +105,39 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) | |||
| prompt = self.prompt.format_prompt(**full_inputs) | |||
| messages = prompt.to_messages() | |||
| prompt_messages = to_prompt_messages(messages) | |||
| result = self.model_instance.run( | |||
| messages=prompt_messages, | |||
| functions=self.functions, | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| tools = [] | |||
| for function in self.functions: | |||
| tool = PromptMessageTool( | |||
| **function | |||
| ) | |||
| tools.append(tool) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| tools=tools, | |||
| stream=False, | |||
| model_parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| ai_message = AIMessage( | |||
| content=result.content, | |||
| content=result.message.content or "", | |||
| additional_kwargs={ | |||
| 'function_call': result.function_call | |||
| 'function_call': { | |||
| 'id': result.message.tool_calls[0].id, | |||
| **result.message.tool_calls[0].function.dict() | |||
| } if result.message.tool_calls else None | |||
| } | |||
| ) | |||
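The model runtime now reports tool calls on result.message.tool_calls, while LangChain's _parse_ai_message still expects the legacy OpenAI function_call dict in additional_kwargs. A sketch of that compatibility mapping in isolation (field names follow the diff above; treat them as assumptions about the runtime's entities, not a public API guarantee):

def tool_calls_to_function_call(message) -> dict | None:
    # Fold the first tool call back into the old single-function-call shape.
    if not message.tool_calls:
        return None
    first = message.tool_calls[0]
    return {'id': first.id, **first.function.dict()}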
| @@ -133,7 +155,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_instance: BaseLLM, | |||
| model_config: ModelConfigEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, | |||
| @@ -147,7 +169,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| system_message=system_message, | |||
| ) | |||
| return cls( | |||
| model_instance=model_instance, | |||
| model_config=model_config, | |||
| llm=FakeLLM(response=''), | |||
| prompt=prompt, | |||
| tools=tools, | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import List, Tuple, Any, Union, Sequence, Optional | |||
| from typing import List, Tuple, Any, Union, Sequence, Optional, cast | |||
| from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent | |||
| from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ | |||
| @@ -13,18 +13,23 @@ from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, | |||
| from langchain.tools import BaseTool | |||
| from pydantic import root_validator | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_manager import ModelInstance | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): | |||
| moving_summary_buffer: str = "" | |||
| moving_summary_index: int = 0 | |||
| summary_model_instance: BaseLLM = None | |||
| model_instance: BaseLLM | |||
| summary_model_config: ModelConfigEntity = None | |||
| model_config: ModelConfigEntity | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| @@ -38,13 +43,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_instance: BaseLLM, | |||
| model_config: ModelConfigEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, | |||
| system_message: Optional[SystemMessage] = SystemMessage( | |||
| content="You are a helpful AI assistant." | |||
| ), | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None, | |||
| **kwargs: Any, | |||
| ) -> BaseSingleActionAgent: | |||
| prompt = cls.create_prompt( | |||
| @@ -52,11 +58,12 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi | |||
| system_message=system_message, | |||
| ) | |||
| return cls( | |||
| model_instance=model_instance, | |||
| model_config=model_config, | |||
| llm=FakeLLM(response=''), | |||
| prompt=prompt, | |||
| tools=tools, | |||
| callback_manager=callback_manager, | |||
| agent_llm_callback=agent_llm_callback, | |||
| **kwargs, | |||
| ) | |||
| @@ -67,28 +74,49 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi | |||
| :param query: | |||
| :return: | |||
| """ | |||
| original_max_tokens = self.model_instance.model_kwargs.max_tokens | |||
| self.model_instance.model_kwargs.max_tokens = 40 | |||
| original_max_tokens = 0 | |||
| for parameter_rule in self.model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| original_max_tokens = (self.model_config.parameters.get(parameter_rule.name) | |||
| or self.model_config.parameters.get(parameter_rule.use_template)) or 0 | |||
| self.model_config.parameters['max_tokens'] = 40 | |||
| prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) | |||
| messages = prompt.to_messages() | |||
| try: | |||
| prompt_messages = to_prompt_messages(messages) | |||
| result = self.model_instance.run( | |||
| messages=prompt_messages, | |||
| functions=self.functions, | |||
| callbacks=None | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| except Exception as e: | |||
| new_exception = self.model_instance.handle_exceptions(e) | |||
| raise new_exception | |||
| function_call = result.function_call | |||
| tools = [] | |||
| for function in self.functions: | |||
| tool = PromptMessageTool( | |||
| **function | |||
| ) | |||
| tools.append(tool) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| tools=tools, | |||
| stream=False, | |||
| model_parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| except Exception as e: | |||
| raise e | |||
| self.model_instance.model_kwargs.max_tokens = original_max_tokens | |||
| self.model_config.parameters['max_tokens'] = original_max_tokens | |||
| return True if function_call else False | |||
| return True if result.message.tool_calls else False | |||
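The max_tokens lookup over parameter_rules is repeated in several places in this PR; pulled out as a standalone helper it reads as below (a sketch for clarity, not a helper that exists in the codebase):

def resolve_max_tokens(model_config) -> int:
    # Find the rule named 'max_tokens' (or templated as such) and read its configured value.
    for rule in model_config.model_schema.parameter_rules:
        if rule.name == 'max_tokens' or (rule.use_template and rule.use_template == 'max_tokens'):
            return (model_config.parameters.get(rule.name)
                    or model_config.parameters.get(rule.use_template)) or 0
    return 0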
| def plan( | |||
| self, | |||
| @@ -113,22 +141,46 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi | |||
| prompt = self.prompt.format_prompt(**full_inputs) | |||
| messages = prompt.to_messages() | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| # summarize messages if rest_tokens < 0 | |||
| try: | |||
| messages = self.summarize_messages_if_needed(messages, functions=self.functions) | |||
| prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions) | |||
| except ExceededLLMTokensLimitError as e: | |||
| return AgentFinish(return_values={"output": str(e)}, log=str(e)) | |||
| prompt_messages = to_prompt_messages(messages) | |||
| result = self.model_instance.run( | |||
| messages=prompt_messages, | |||
| functions=self.functions, | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| tools = [] | |||
| for function in self.functions: | |||
| tool = PromptMessageTool( | |||
| **function | |||
| ) | |||
| tools.append(tool) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| tools=tools, | |||
| stream=False, | |||
| callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [], | |||
| model_parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| ai_message = AIMessage( | |||
| content=result.content, | |||
| content=result.message.content or "", | |||
| additional_kwargs={ | |||
| 'function_call': result.function_call | |||
| 'function_call': { | |||
| 'id': result.message.tool_calls[0].id, | |||
| **result.message.tool_calls[0].function.dict() | |||
| } if result.message.tool_calls else None | |||
| } | |||
| ) | |||
| agent_decision = _parse_ai_message(ai_message) | |||
| @@ -158,9 +210,14 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi | |||
| except ValueError: | |||
| return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") | |||
| def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: | |||
| def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]: | |||
| # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 | |||
| rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) | |||
| rest_tokens = self.get_message_rest_tokens( | |||
| self.model_config, | |||
| messages, | |||
| **kwargs | |||
| ) | |||
| rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens | |||
| if rest_tokens >= 0: | |||
| return messages | |||
| @@ -210,19 +267,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi | |||
| ai_prefix="AI", | |||
| ) | |||
| chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT) | |||
| chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) | |||
| return chain.predict(summary=existing_summary, new_lines=new_lines) | |||
| def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||
| def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int: | |||
| """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. | |||
| Official documentation: https://github.com/openai/openai-cookbook/blob/ | |||
| main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" | |||
| if model_instance.model_provider.provider_name == 'azure_openai': | |||
| model = model_instance.base_model_name | |||
| if model_config.provider == 'azure_openai': | |||
| model = model_config.model | |||
| model = model.replace("gpt-35", "gpt-3.5") | |||
| else: | |||
| model = model_instance.base_model_name | |||
| model = model_config.credentials.get("base_model_name") | |||
| tiktoken_ = _import_tiktoken() | |||
| try: | |||
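For token counting, Azure OpenAI deployment names such as "gpt-35-turbo" are remapped to the names tiktoken understands. A small standalone check of that remapping (the deployment name is a hypothetical example):

import tiktoken  # same dependency the method above relies on

model = "gpt-35-turbo".replace("gpt-35", "gpt-3.5")   # -> "gpt-3.5-turbo"
encoding = tiktoken.encoding_for_model(model)
print(len(encoding.encode("You are a helpful AI assistant.")))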
| @@ -1,158 +0,0 @@ | |||
| import json | |||
| from typing import Tuple, List, Any, Union, Sequence, Optional, cast | |||
| from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent | |||
| from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from langchain.tools import BaseTool | |||
| from pydantic import root_validator | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| """ | |||
| A Multi Dataset Retrieve Agent driven by Router. | |||
| """ | |||
| model_instance: BaseLLM | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| @root_validator | |||
| def validate_llm(cls, values: dict) -> dict: | |||
| return values | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| return should use agent | |||
| :param query: | |||
| :return: | |||
| """ | |||
| return True | |||
| def plan( | |||
| self, | |||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, along with observations | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| if len(self.tools) == 0: | |||
| return AgentFinish(return_values={"output": ''}, log='') | |||
| elif len(self.tools) == 1: | |||
| tool = next(iter(self.tools)) | |||
| tool = cast(DatasetRetrieverTool, tool) | |||
| rst = tool.run(tool_input={'query': kwargs['input']}) | |||
| # output = '' | |||
| # rst_json = json.loads(rst) | |||
| # for item in rst_json: | |||
| # output += f'{item["content"]}\n' | |||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||
| if intermediate_steps: | |||
| _, observation = intermediate_steps[-1] | |||
| return AgentFinish(return_values={"output": observation}, log=observation) | |||
| try: | |||
| agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs) | |||
| if isinstance(agent_decision, AgentAction): | |||
| tool_inputs = agent_decision.tool_input | |||
| if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs: | |||
| tool_inputs['query'] = kwargs['input'] | |||
| agent_decision.tool_input = tool_inputs | |||
| else: | |||
| agent_decision.return_values['output'] = '' | |||
| return agent_decision | |||
| except Exception as e: | |||
| new_exception = self.model_instance.handle_exceptions(e) | |||
| raise new_exception | |||
| def real_plan( | |||
| self, | |||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, along with observations | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| agent_scratchpad = _format_intermediate_steps(intermediate_steps) | |||
| selected_inputs = { | |||
| k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" | |||
| } | |||
| full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) | |||
| prompt = self.prompt.format_prompt(**full_inputs) | |||
| messages = prompt.to_messages() | |||
| prompt_messages = to_prompt_messages(messages) | |||
| result = self.model_instance.run( | |||
| messages=prompt_messages, | |||
| functions=self.functions, | |||
| ) | |||
| ai_message = AIMessage( | |||
| content=result.content, | |||
| additional_kwargs={ | |||
| 'function_call': result.function_call | |||
| } | |||
| ) | |||
| agent_decision = _parse_ai_message(ai_message) | |||
| return agent_decision | |||
| async def aplan( | |||
| self, | |||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| raise NotImplementedError() | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_instance: BaseLLM, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, | |||
| system_message: Optional[SystemMessage] = SystemMessage( | |||
| content="You are a helpful AI assistant." | |||
| ), | |||
| **kwargs: Any, | |||
| ) -> BaseSingleActionAgent: | |||
| prompt = cls.create_prompt( | |||
| extra_prompt_messages=extra_prompt_messages, | |||
| system_message=system_message, | |||
| ) | |||
| return cls( | |||
| model_instance=model_instance, | |||
| llm=FakeLLM(response=''), | |||
| prompt=prompt, | |||
| tools=tools, | |||
| callback_manager=callback_manager, | |||
| **kwargs, | |||
| ) | |||
| @@ -12,9 +12,7 @@ from langchain.tools import BaseTool | |||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.model_providers.models.entity.model_params import ModelMode | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| @@ -69,10 +67,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||
| return True | |||
| def plan( | |||
| self, | |||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| self, | |||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| @@ -101,8 +99,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||
| try: | |||
| full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) | |||
| except Exception as e: | |||
| new_exception = self.llm_chain.model_instance.handle_exceptions(e) | |||
| raise new_exception | |||
| raise e | |||
| try: | |||
| agent_decision = self.output_parser.parse(full_output) | |||
| @@ -119,6 +116,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||
| except OutputParserException: | |||
| return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " | |||
| "I don't know how to respond to that."}, "") | |||
| @classmethod | |||
| def create_prompt( | |||
| cls, | |||
| @@ -182,7 +180,7 @@ Thought: {agent_scratchpad} | |||
| return PromptTemplate(template=template, input_variables=input_variables) | |||
| def _construct_scratchpad( | |||
| self, intermediate_steps: List[Tuple[AgentAction, str]] | |||
| self, intermediate_steps: List[Tuple[AgentAction, str]] | |||
| ) -> str: | |||
| agent_scratchpad = "" | |||
| for action, observation in intermediate_steps: | |||
| @@ -193,7 +191,7 @@ Thought: {agent_scratchpad} | |||
| raise ValueError("agent_scratchpad should be of type string.") | |||
| if agent_scratchpad: | |||
| llm_chain = cast(LLMChain, self.llm_chain) | |||
| if llm_chain.model_instance.model_mode == ModelMode.CHAT: | |||
| if llm_chain.model_config.mode == "chat": | |||
| return ( | |||
| f"This was your previous work " | |||
| f"(but I haven't seen any of it! I only see what " | |||
| @@ -207,7 +205,7 @@ Thought: {agent_scratchpad} | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_instance: BaseLLM, | |||
| model_config: ModelConfigEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| output_parser: Optional[AgentOutputParser] = None, | |||
| @@ -221,7 +219,7 @@ Thought: {agent_scratchpad} | |||
| ) -> Agent: | |||
| """Construct an agent from an LLM and tools.""" | |||
| cls._validate_tools(tools) | |||
| if model_instance.model_mode == ModelMode.CHAT: | |||
| if model_config.mode == "chat": | |||
| prompt = cls.create_prompt( | |||
| tools, | |||
| prefix=prefix, | |||
| @@ -238,10 +236,16 @@ Thought: {agent_scratchpad} | |||
| format_instructions=format_instructions, | |||
| input_variables=input_variables | |||
| ) | |||
| llm_chain = LLMChain( | |||
| model_instance=model_instance, | |||
| model_config=model_config, | |||
| prompt=prompt, | |||
| callback_manager=callback_manager, | |||
| parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| tool_names = [tool.name for tool in tools] | |||
| _output_parser = output_parser | |||
| @@ -13,10 +13,11 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, | |||
| from langchain.tools import BaseTool | |||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError | |||
| from core.chain.llm_chain import LLMChain | |||
| from core.model_providers.models.entity.model_params import ModelMode | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| @@ -54,7 +55,7 @@ Action: | |||
| class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| moving_summary_buffer: str = "" | |||
| moving_summary_index: int = 0 | |||
| summary_model_instance: BaseLLM = None | |||
| summary_model_config: ModelConfigEntity = None | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| @@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, | |||
| along with observations | |||
| along with observations | |||
| callbacks: Callbacks to run. | |||
| **kwargs: User inputs. | |||
| @@ -96,15 +97,16 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| if prompts: | |||
| messages = prompts[0].to_messages() | |||
| rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages) | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages) | |||
| if rest_tokens < 0: | |||
| full_inputs = self.summarize_messages(intermediate_steps, **kwargs) | |||
| try: | |||
| full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) | |||
| except Exception as e: | |||
| new_exception = self.llm_chain.model_instance.handle_exceptions(e) | |||
| raise new_exception | |||
| raise e | |||
| try: | |||
| agent_decision = self.output_parser.parse(full_output) | |||
| @@ -119,7 +121,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| "I don't know how to respond to that."}, "") | |||
| def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): | |||
| if len(intermediate_steps) >= 2 and self.summary_model_instance: | |||
| if len(intermediate_steps) >= 2 and self.summary_model_config: | |||
| should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] | |||
| should_summary_messages = [AIMessage(content=observation) | |||
| for _, observation in should_summary_intermediate_steps] | |||
| @@ -153,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| ai_prefix="AI", | |||
| ) | |||
| chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT) | |||
| chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) | |||
| return chain.predict(summary=existing_summary, new_lines=new_lines) | |||
| @classmethod | |||
| @@ -229,7 +231,7 @@ Thought: {agent_scratchpad} | |||
| raise ValueError("agent_scratchpad should be of type string.") | |||
| if agent_scratchpad: | |||
| llm_chain = cast(LLMChain, self.llm_chain) | |||
| if llm_chain.model_instance.model_mode == ModelMode.CHAT: | |||
| if llm_chain.model_config.mode == "chat": | |||
| return ( | |||
| f"This was your previous work " | |||
| f"(but I haven't seen any of it! I only see what " | |||
| @@ -243,7 +245,7 @@ Thought: {agent_scratchpad} | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| model_instance: BaseLLM, | |||
| model_config: ModelConfigEntity, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| output_parser: Optional[AgentOutputParser] = None, | |||
| @@ -253,11 +255,12 @@ Thought: {agent_scratchpad} | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[List[str]] = None, | |||
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None, | |||
| **kwargs: Any, | |||
| ) -> Agent: | |||
| """Construct an agent from an LLM and tools.""" | |||
| cls._validate_tools(tools) | |||
| if model_instance.model_mode == ModelMode.CHAT: | |||
| if model_config.mode == "chat": | |||
| prompt = cls.create_prompt( | |||
| tools, | |||
| prefix=prefix, | |||
| @@ -275,9 +278,15 @@ Thought: {agent_scratchpad} | |||
| input_variables=input_variables, | |||
| ) | |||
| llm_chain = LLMChain( | |||
| model_instance=model_instance, | |||
| model_config=model_config, | |||
| prompt=prompt, | |||
| callback_manager=callback_manager, | |||
| agent_llm_callback=agent_llm_callback, | |||
| parameters={ | |||
| 'temperature': 0.2, | |||
| 'top_p': 0.3, | |||
| 'max_tokens': 1500 | |||
| } | |||
| ) | |||
| tool_names = [tool.name for tool in tools] | |||
| _output_parser = output_parser | |||
| @@ -4,10 +4,10 @@ from typing import Union, Optional | |||
| from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.tools import BaseTool | |||
| from pydantic import BaseModel, Extra | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent | |||
| from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent | |||
| from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| @@ -15,9 +15,11 @@ from core.agent.agent.structed_multi_dataset_router_agent import StructuredMulti | |||
| from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent | |||
| from langchain.agents import AgentExecutor as LCAgentExecutor | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.entities.message_entities import prompt_messages_to_lc_messages | |||
| from core.helper import moderation | |||
| from core.model_providers.error import LLMError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.errors.invoke import InvokeError | |||
| from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| @@ -31,14 +33,15 @@ class PlanningStrategy(str, enum.Enum): | |||
| class AgentConfiguration(BaseModel): | |||
| strategy: PlanningStrategy | |||
| model_instance: BaseLLM | |||
| model_config: ModelConfigEntity | |||
| tools: list[BaseTool] | |||
| summary_model_instance: BaseLLM = None | |||
| memory: Optional[BaseChatMemory] = None | |||
| summary_model_config: Optional[ModelConfigEntity] = None | |||
| memory: Optional[TokenBufferMemory] = None | |||
| callbacks: Callbacks = None | |||
| max_iterations: int = 6 | |||
| max_execution_time: Optional[float] = None | |||
| early_stopping_method: str = "generate" | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None | |||
| # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit | |||
| class Config: | |||
| @@ -62,34 +65,42 @@ class AgentExecutor: | |||
| def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: | |||
| if self.configuration.strategy == PlanningStrategy.REACT: | |||
| agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| output_parser=StructuredChatOutputParser(), | |||
| summary_model_instance=self.configuration.summary_model_instance | |||
| if self.configuration.summary_model_instance else None, | |||
| summary_model_config=self.configuration.summary_model_config | |||
| if self.configuration.summary_model_config else None, | |||
| agent_llm_callback=self.configuration.agent_llm_callback, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: | |||
| agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory | |||
| summary_model_instance=self.configuration.summary_model_instance | |||
| if self.configuration.summary_model_instance else None, | |||
| extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) | |||
| if self.configuration.memory else None, # used to read chat history from memory | |||
| summary_model_config=self.configuration.summary_model_config | |||
| if self.configuration.summary_model_config else None, | |||
| agent_llm_callback=self.configuration.agent_llm_callback, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] | |||
| self.configuration.tools = [t for t in self.configuration.tools | |||
| if isinstance(t, DatasetRetrieverTool) | |||
| or isinstance(t, DatasetMultiRetrieverTool)] | |||
| agent = MultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, | |||
| extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) | |||
| if self.configuration.memory else None, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] | |||
| self.configuration.tools = [t for t in self.configuration.tools | |||
| if isinstance(t, DatasetRetrieverTool) | |||
| or isinstance(t, DatasetMultiRetrieverTool)] | |||
| agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| model_config=self.configuration.model_config, | |||
| tools=self.configuration.tools, | |||
| output_parser=StructuredChatOutputParser(), | |||
| verbose=True | |||
| @@ -104,11 +115,11 @@ class AgentExecutor: | |||
| def run(self, query: str) -> AgentExecuteResult: | |||
| moderation_result = moderation.check_moderation( | |||
| self.configuration.model_instance.model_provider, | |||
| self.configuration.model_config, | |||
| query | |||
| ) | |||
| if not moderation_result: | |||
| if moderation_result: | |||
| return AgentExecuteResult( | |||
| output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.", | |||
| strategy=self.configuration.strategy, | |||
| @@ -118,7 +129,6 @@ class AgentExecutor: | |||
| agent_executor = LCAgentExecutor.from_agent_and_tools( | |||
| agent=self.agent, | |||
| tools=self.configuration.tools, | |||
| memory=self.configuration.memory, | |||
| max_iterations=self.configuration.max_iterations, | |||
| max_execution_time=self.configuration.max_execution_time, | |||
| early_stopping_method=self.configuration.early_stopping_method, | |||
| @@ -126,8 +136,8 @@ class AgentExecutor: | |||
| ) | |||
| try: | |||
| output = agent_executor.run(query) | |||
| except LLMError as ex: | |||
| output = agent_executor.run(input=query) | |||
| except InvokeError as ex: | |||
| raise ex | |||
| except Exception as ex: | |||
| logging.exception("agent_executor run failed") | |||
| @@ -0,0 +1,251 @@ | |||
| import json | |||
| import logging | |||
| from typing import cast | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.app_runner.app_runner import AppRunner | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.features.agent_runner import AgentRunnerFeature | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.llm_entities import LLMUsage | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message, App, MessageChain, MessageAgentThought | |||
| logger = logging.getLogger(__name__) | |||
| class AgentApplicationRunner(AppRunner): | |||
| """ | |||
| Agent Application Runner | |||
| """ | |||
| def run(self, application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| conversation: Conversation, | |||
| message: Message) -> None: | |||
| """ | |||
| Run agent application | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| :return: | |||
| """ | |||
| app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() | |||
| if not app_record: | |||
| raise ValueError(f"App not found") | |||
| app_orchestration_config = application_generate_entity.app_orchestration_config_entity | |||
| inputs = application_generate_entity.inputs | |||
| query = application_generate_entity.query | |||
| files = application_generate_entity.files | |||
| # Pre-calculate the number of tokens in the prompt messages, | |||
| # and return the number of tokens remaining under the model context size and max_tokens limits. | |||
| # If not enough tokens remain, raise an exception. | |||
| # Include: prompt template, inputs, query(optional), files(optional) | |||
| # Not Include: memory, external data, dataset context | |||
| self.get_pre_calculate_rest_tokens( | |||
| app_record=app_record, | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_template_entity=app_orchestration_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query | |||
| ) | |||
| memory = None | |||
| if application_generate_entity.conversation_id: | |||
| # get memory of conversation (read-only) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, | |||
| model=app_orchestration_config.model_config.model | |||
| ) | |||
| memory = TokenBufferMemory( | |||
| conversation=conversation, | |||
| model_instance=model_instance | |||
| ) | |||
| # reorganize all inputs and template to prompt messages | |||
| # Include: prompt template, inputs, query(optional), files(optional) | |||
| # memory(optional) | |||
| prompt_messages, stop = self.originze_prompt_messages( | |||
| app_record=app_record, | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_template_entity=app_orchestration_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| context=None, | |||
| memory=memory | |||
| ) | |||
| # Create MessageChain | |||
| message_chain = self._init_message_chain( | |||
| message=message, | |||
| query=query | |||
| ) | |||
| # add agent callback to record agent thoughts | |||
| agent_callback = AgentLoopGatherCallbackHandler( | |||
| model_config=app_orchestration_config.model_config, | |||
| message=message, | |||
| queue_manager=queue_manager, | |||
| message_chain=message_chain | |||
| ) | |||
| # init LLM Callback | |||
| agent_llm_callback = AgentLLMCallback( | |||
| agent_callback=agent_callback | |||
| ) | |||
| agent_runner = AgentRunnerFeature( | |||
| tenant_id=application_generate_entity.tenant_id, | |||
| app_orchestration_config=app_orchestration_config, | |||
| model_config=app_orchestration_config.model_config, | |||
| config=app_orchestration_config.agent, | |||
| queue_manager=queue_manager, | |||
| message=message, | |||
| user_id=application_generate_entity.user_id, | |||
| agent_llm_callback=agent_llm_callback, | |||
| callback=agent_callback, | |||
| memory=memory | |||
| ) | |||
| # agent run | |||
| result = agent_runner.run( | |||
| query=query, | |||
| invoke_from=application_generate_entity.invoke_from | |||
| ) | |||
| if result: | |||
| self._save_message_chain( | |||
| message_chain=message_chain, | |||
| output_text=result | |||
| ) | |||
| if (result | |||
| and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE | |||
| and app_orchestration_config.prompt_template.simple_prompt_template | |||
| ): | |||
| # Direct output if the agent result exists and a simple pre-prompt is set | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_orchestration_config=app_orchestration_config, | |||
| prompt_messages=prompt_messages, | |||
| stream=application_generate_entity.stream, | |||
| text=result, | |||
| usage=self._get_usage_of_all_agent_thoughts( | |||
| model_config=app_orchestration_config.model_config, | |||
| message=message | |||
| ) | |||
| ) | |||
| else: | |||
| # As normal LLM run, agent result as context | |||
| context = result | |||
| # reorganize all inputs and template to prompt messages | |||
| # Include: prompt template, inputs, query(optional), files(optional) | |||
| # memory(optional), external data, dataset context(optional) | |||
| prompt_messages, stop = self.originze_prompt_messages( | |||
| app_record=app_record, | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_template_entity=app_orchestration_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| context=context, | |||
| memory=memory | |||
| ) | |||
| # Re-calculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit | |||
| self.recale_llm_max_tokens( | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| # Invoke model | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, | |||
| model=app_orchestration_config.model_config.model | |||
| ) | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=app_orchestration_config.model_config.parameters, | |||
| stop=stop, | |||
| stream=application_generate_entity.stream, | |||
| user=application_generate_entity.user_id, | |||
| ) | |||
| # handle invoke result | |||
| self._handle_invoke_result( | |||
| invoke_result=invoke_result, | |||
| queue_manager=queue_manager, | |||
| stream=application_generate_entity.stream | |||
| ) | |||
| def _init_message_chain(self, message: Message, query: str) -> MessageChain: | |||
| """ | |||
| Init MessageChain | |||
| :param message: message | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| message_chain = MessageChain( | |||
| message_id=message.id, | |||
| type="AgentExecutor", | |||
| input=json.dumps({ | |||
| "input": query | |||
| }) | |||
| ) | |||
| db.session.add(message_chain) | |||
| db.session.commit() | |||
| return message_chain | |||
| def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None: | |||
| """ | |||
| Save MessageChain | |||
| :param message_chain: message chain | |||
| :param output_text: output text | |||
| :return: | |||
| """ | |||
| message_chain.output = json.dumps({ | |||
| "output": output_text | |||
| }) | |||
| db.session.commit() | |||
| def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, | |||
| message: Message) -> LLMUsage: | |||
| """ | |||
| Get usage of all agent thoughts | |||
| :param model_config: model config | |||
| :param message: message | |||
| :return: | |||
| """ | |||
| agent_thoughts = (db.session.query(MessageAgentThought) | |||
| .filter(MessageAgentThought.message_id == message.id).all()) | |||
| all_message_tokens = 0 | |||
| all_answer_tokens = 0 | |||
| for agent_thought in agent_thoughts: | |||
| all_message_tokens += agent_thought.message_tokens | |||
| all_answer_tokens += agent_thought.answer_tokens | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| return model_type_instance._calc_response_usage( | |||
| model_config.model, | |||
| model_config.credentials, | |||
| all_message_tokens, | |||
| all_answer_tokens | |||
| ) | |||
| @@ -0,0 +1,267 @@ | |||
| import time | |||
| from typing import cast, Optional, List, Tuple, Generator, Union | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity | |||
| from core.file.file_obj import FileObj | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage | |||
| from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelPropertyKey | |||
| from core.model_runtime.errors.invoke import InvokeBadRequestError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.prompt.prompt_transform import PromptTransform | |||
| from models.model import App | |||
| class AppRunner: | |||
| def get_pre_calculate_rest_tokens(self, app_record: App, | |||
| model_config: ModelConfigEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict[str, str], | |||
| files: list[FileObj], | |||
| query: Optional[str] = None) -> int: | |||
| """ | |||
| Get pre calculate rest tokens | |||
| :param app_record: app record | |||
| :param model_config: model config entity | |||
| :param prompt_template_entity: prompt template entity | |||
| :param inputs: inputs | |||
| :param files: files | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) | |||
| max_tokens = 0 | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| max_tokens = (model_config.parameters.get(parameter_rule.name) | |||
| or model_config.parameters.get(parameter_rule.use_template)) or 0 | |||
| if model_context_tokens is None: | |||
| return -1 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| # get prompt messages without memory and context | |||
| prompt_messages, stop = self.originze_prompt_messages( | |||
| app_record=app_record, | |||
| model_config=model_config, | |||
| prompt_template_entity=prompt_template_entity, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query | |||
| ) | |||
| prompt_tokens = model_type_instance.get_num_tokens( | |||
| model_config.model, | |||
| model_config.credentials, | |||
| prompt_messages | |||
| ) | |||
| rest_tokens = model_context_tokens - max_tokens - prompt_tokens | |||
| if rest_tokens < 0: | |||
| raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | |||
| "or shrink the max token, or switch to a llm with a larger token limit size.") | |||
| return rest_tokens | |||
| def recale_llm_max_tokens(self, model_config: ModelConfigEntity, | |||
| prompt_messages: List[PromptMessage]): | |||
| # recalc max_tokens if prompt_tokens + max_tokens exceeds the model token limit | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) | |||
| max_tokens = 0 | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| max_tokens = (model_config.parameters.get(parameter_rule.name) | |||
| or model_config.parameters.get(parameter_rule.use_template)) or 0 | |||
| if model_context_tokens is None: | |||
| return -1 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| prompt_tokens = model_type_instance.get_num_tokens( | |||
| model_config.model, | |||
| model_config.credentials, | |||
| prompt_messages | |||
| ) | |||
| if prompt_tokens + max_tokens > model_context_tokens: | |||
| max_tokens = max(model_context_tokens - prompt_tokens, 16) | |||
| for parameter_rule in model_config.model_schema.parameter_rules: | |||
| if (parameter_rule.name == 'max_tokens' | |||
| or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): | |||
| model_config.parameters[parameter_rule.name] = max_tokens | |||
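recale_llm_max_tokens shrinks the completion budget when prompt plus max_tokens would overflow the context window, never going below 16. With hypothetical numbers the clamp behaves like this:

# Illustrative numbers only.
model_context_tokens = 4096
prompt_tokens = 3900
max_tokens = 1500

if prompt_tokens + max_tokens > model_context_tokens:
    # Keep at least 16 tokens for the completion, as in the method above.
    max_tokens = max(model_context_tokens - prompt_tokens, 16)

print(max_tokens)  # 196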
| def originze_prompt_messages(self, app_record: App, | |||
| model_config: ModelConfigEntity, | |||
| prompt_template_entity: PromptTemplateEntity, | |||
| inputs: dict[str, str], | |||
| files: list[FileObj], | |||
| query: Optional[str] = None, | |||
| context: Optional[str] = None, | |||
| memory: Optional[TokenBufferMemory] = None) \ | |||
| -> Tuple[List[PromptMessage], Optional[List[str]]]: | |||
| """ | |||
| Organize prompt messages | |||
| :param context: | |||
| :param app_record: app record | |||
| :param model_config: model config entity | |||
| :param prompt_template_entity: prompt template entity | |||
| :param inputs: inputs | |||
| :param files: files | |||
| :param query: query | |||
| :param memory: memory | |||
| :return: | |||
| """ | |||
| prompt_transform = PromptTransform() | |||
| # get prompt without memory and context | |||
| if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE: | |||
| prompt_messages, stop = prompt_transform.get_prompt( | |||
| app_mode=app_record.mode, | |||
| prompt_template_entity=prompt_template_entity, | |||
| inputs=inputs, | |||
| query=query if query else '', | |||
| files=files, | |||
| context=context, | |||
| memory=memory, | |||
| model_config=model_config | |||
| ) | |||
| else: | |||
| prompt_messages = prompt_transform.get_advanced_prompt( | |||
| app_mode=app_record.mode, | |||
| prompt_template_entity=prompt_template_entity, | |||
| inputs=inputs, | |||
| query=query, | |||
| files=files, | |||
| context=context, | |||
| memory=memory, | |||
| model_config=model_config | |||
| ) | |||
| stop = model_config.stop | |||
| return prompt_messages, stop | |||
| def direct_output(self, queue_manager: ApplicationQueueManager, | |||
| app_orchestration_config: AppOrchestrationConfigEntity, | |||
| prompt_messages: list, | |||
| text: str, | |||
| stream: bool, | |||
| usage: Optional[LLMUsage] = None) -> None: | |||
| """ | |||
| Direct output | |||
| :param queue_manager: application queue manager | |||
| :param app_orchestration_config: app orchestration config | |||
| :param prompt_messages: prompt messages | |||
| :param text: text | |||
| :param stream: stream | |||
| :param usage: usage | |||
| :return: | |||
| """ | |||
| if stream: | |||
| index = 0 | |||
| for token in text: | |||
| queue_manager.publish_chunk_message(LLMResultChunk( | |||
| model=app_orchestration_config.model_config.model, | |||
| prompt_messages=prompt_messages, | |||
| delta=LLMResultChunkDelta( | |||
| index=index, | |||
| message=AssistantPromptMessage(content=token) | |||
| ) | |||
| )) | |||
| index += 1 | |||
| time.sleep(0.01) | |||
| queue_manager.publish_message_end( | |||
| llm_result=LLMResult( | |||
| model=app_orchestration_config.model_config.model, | |||
| prompt_messages=prompt_messages, | |||
| message=AssistantPromptMessage(content=text), | |||
| usage=usage if usage else LLMUsage.empty_usage() | |||
| ) | |||
| ) | |||
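direct_output fakes streaming for a pre-computed answer by publishing one character per chunk with a small delay. The same idea outside the queue machinery, with print standing in for publish_chunk_message (the stand-in is an assumption for illustration):

import time

def pseudo_stream(text: str) -> None:
    # One character per "chunk", mirroring the loop in direct_output.
    for token in text:
        print(token, end="", flush=True)
        time.sleep(0.01)

pseudo_stream("I apologize for any confusion, but I'm an AI assistant...")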
| def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], | |||
| queue_manager: ApplicationQueueManager, | |||
| stream: bool) -> None: | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| :param queue_manager: application queue manager | |||
| :param stream: stream | |||
| :return: | |||
| """ | |||
| if not stream: | |||
| self._handle_invoke_result_direct( | |||
| invoke_result=invoke_result, | |||
| queue_manager=queue_manager | |||
| ) | |||
| else: | |||
| self._handle_invoke_result_stream( | |||
| invoke_result=invoke_result, | |||
| queue_manager=queue_manager | |||
| ) | |||
| def _handle_invoke_result_direct(self, invoke_result: LLMResult, | |||
| queue_manager: ApplicationQueueManager) -> None: | |||
| """ | |||
| Handle invoke result direct | |||
| :param invoke_result: invoke result | |||
| :param queue_manager: application queue manager | |||
| :return: | |||
| """ | |||
| queue_manager.publish_message_end( | |||
| llm_result=invoke_result | |||
| ) | |||
| def _handle_invoke_result_stream(self, invoke_result: Generator, | |||
| queue_manager: ApplicationQueueManager) -> None: | |||
| """ | |||
| Handle invoke result | |||
| :param invoke_result: invoke result | |||
| :param queue_manager: application queue manager | |||
| :return: | |||
| """ | |||
| model = None | |||
| prompt_messages = [] | |||
| text = '' | |||
| usage = None | |||
| for result in invoke_result: | |||
| queue_manager.publish_chunk_message(result) | |||
| text += result.delta.message.content | |||
| if not model: | |||
| model = result.model | |||
| if not prompt_messages: | |||
| prompt_messages = result.prompt_messages | |||
| if not usage and result.delta.usage: | |||
| usage = result.delta.usage | |||
| llm_result = LLMResult( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| message=AssistantPromptMessage(content=text), | |||
| usage=usage | |||
| ) | |||
| queue_manager.publish_message_end( | |||
| llm_result=llm_result | |||
| ) | |||
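_handle_invoke_result_stream forwards every chunk to the queue while also accumulating the final text, model and usage so a single complete LLMResult can be published at the end. A reduced sketch of that accumulation with a stand-in chunk type (the dataclasses below are assumptions, not the runtime's real entities):

from dataclasses import dataclass

@dataclass
class Delta:
    content: str

@dataclass
class Chunk:
    model: str
    delta: Delta

def accumulate(chunks):
    model, text = None, ''
    for chunk in chunks:
        # forward the chunk here (publish_chunk_message in the real runner)
        text += chunk.delta.content
        model = model or chunk.model
    return model, text

print(accumulate([Chunk('gpt-3.5-turbo', Delta('Hel')), Chunk('gpt-3.5-turbo', Delta('lo'))]))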
| @@ -0,0 +1,363 @@ | |||
| import logging | |||
| from typing import Tuple, Optional | |||
| from core.app_runner.app_runner import AppRunner | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \ | |||
| AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.features.annotation_reply import AnnotationReplyFeature | |||
| from core.features.dataset_retrieval import DatasetRetrievalFeature | |||
| from core.features.external_data_fetch import ExternalDataFetchFeature | |||
| from core.features.hosting_moderation import HostingModerationFeature | |||
| from core.features.moderation import ModerationFeature | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| from core.moderation.base import ModerationException | |||
| from core.prompt.prompt_transform import AppMode | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message, App, MessageAnnotation | |||
| logger = logging.getLogger(__name__) | |||
| class BasicApplicationRunner(AppRunner): | |||
| """ | |||
| Basic Application Runner | |||
| """ | |||
| def run(self, application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| conversation: Conversation, | |||
| message: Message) -> None: | |||
| """ | |||
| Run application | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: application queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| :return: | |||
| """ | |||
| app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() | |||
| if not app_record: | |||
| raise ValueError(f"App not found") | |||
| app_orchestration_config = application_generate_entity.app_orchestration_config_entity | |||
| inputs = application_generate_entity.inputs | |||
| query = application_generate_entity.query | |||
| files = application_generate_entity.files | |||
| # Pre-calculate the number of tokens in the prompt messages, | |||
| # and return the number of tokens remaining under the model context size and max_tokens limits. | |||
| # If not enough tokens remain, raise an exception. | |||
| # Include: prompt template, inputs, query(optional), files(optional) | |||
| # Not Include: memory, external data, dataset context | |||
| self.get_pre_calculate_rest_tokens( | |||
| app_record=app_record, | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_template_entity=app_orchestration_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query | |||
| ) | |||
| memory = None | |||
| if application_generate_entity.conversation_id: | |||
| # get memory of conversation (read-only) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, | |||
| model=app_orchestration_config.model_config.model | |||
| ) | |||
| memory = TokenBufferMemory( | |||
| conversation=conversation, | |||
| model_instance=model_instance | |||
| ) | |||
| # Organize all inputs and the template into prompt messages. | |||
| # Includes: prompt template, inputs, query (optional), files (optional), | |||
| # memory (optional) | |||
| prompt_messages, stop = self.originze_prompt_messages( | |||
| app_record=app_record, | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_template_entity=app_orchestration_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| memory=memory | |||
| ) | |||
| # moderation | |||
| try: | |||
| # process sensitive_word_avoidance | |||
| _, inputs, query = self.moderation_for_inputs( | |||
| app_id=app_record.id, | |||
| tenant_id=application_generate_entity.tenant_id, | |||
| app_orchestration_config_entity=app_orchestration_config, | |||
| inputs=inputs, | |||
| query=query, | |||
| ) | |||
| except ModerationException as e: | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_orchestration_config=app_orchestration_config, | |||
| prompt_messages=prompt_messages, | |||
| text=str(e), | |||
| stream=application_generate_entity.stream | |||
| ) | |||
| return | |||
| if query: | |||
| # annotation reply | |||
| annotation_reply = self.query_app_annotations_to_reply( | |||
| app_record=app_record, | |||
| message=message, | |||
| query=query, | |||
| user_id=application_generate_entity.user_id, | |||
| invoke_from=application_generate_entity.invoke_from | |||
| ) | |||
| if annotation_reply: | |||
| queue_manager.publish_annotation_reply( | |||
| message_annotation_id=annotation_reply.id | |||
| ) | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_orchestration_config=app_orchestration_config, | |||
| prompt_messages=prompt_messages, | |||
| text=annotation_reply.content, | |||
| stream=application_generate_entity.stream | |||
| ) | |||
| return | |||
| # fill in variable inputs from external data tools if they exist | |||
| external_data_tools = app_orchestration_config.external_data_variables | |||
| if external_data_tools: | |||
| inputs = self.fill_in_inputs_from_external_data_tools( | |||
| tenant_id=app_record.tenant_id, | |||
| app_id=app_record.id, | |||
| external_data_tools=external_data_tools, | |||
| inputs=inputs, | |||
| query=query | |||
| ) | |||
| # get context from datasets | |||
| context = None | |||
| if app_orchestration_config.dataset: | |||
| context = self.retrieve_dataset_context( | |||
| tenant_id=app_record.tenant_id, | |||
| app_record=app_record, | |||
| queue_manager=queue_manager, | |||
| model_config=app_orchestration_config.model_config, | |||
| show_retrieve_source=app_orchestration_config.show_retrieve_source, | |||
| dataset_config=app_orchestration_config.dataset, | |||
| message=message, | |||
| inputs=inputs, | |||
| query=query, | |||
| user_id=application_generate_entity.user_id, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| memory=memory | |||
| ) | |||
| # Reorganize all inputs and the template into prompt messages. | |||
| # Includes: prompt template, inputs, query (optional), files (optional), | |||
| # memory (optional), external data, dataset context (optional) | |||
| prompt_messages, stop = self.originze_prompt_messages( | |||
| app_record=app_record, | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_template_entity=app_orchestration_config.prompt_template, | |||
| inputs=inputs, | |||
| files=files, | |||
| query=query, | |||
| context=context, | |||
| memory=memory | |||
| ) | |||
| # check hosting moderation | |||
| hosting_moderation_result = self.check_hosting_moderation( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| if hosting_moderation_result: | |||
| return | |||
| # Re-calculate max_tokens if prompt_tokens + max_tokens exceeds the model's token limit | |||
| self.recale_llm_max_tokens( | |||
| model_config=app_orchestration_config.model_config, | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| # Invoke model | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle, | |||
| model=app_orchestration_config.model_config.model | |||
| ) | |||
| invoke_result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=app_orchestration_config.model_config.parameters, | |||
| stop=stop, | |||
| stream=application_generate_entity.stream, | |||
| user=application_generate_entity.user_id, | |||
| ) | |||
| # handle invoke result | |||
| self._handle_invoke_result( | |||
| invoke_result=invoke_result, | |||
| queue_manager=queue_manager, | |||
| stream=application_generate_entity.stream | |||
| ) | |||
| def moderation_for_inputs(self, app_id: str, | |||
| tenant_id: str, | |||
| app_orchestration_config_entity: AppOrchestrationConfigEntity, | |||
| inputs: dict, | |||
| query: str) -> Tuple[bool, dict, str]: | |||
| """ | |||
| Process sensitive_word_avoidance. | |||
| :param app_id: app id | |||
| :param tenant_id: tenant id | |||
| :param app_orchestration_config_entity: app orchestration config entity | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| moderation_feature = ModerationFeature() | |||
| return moderation_feature.check( | |||
| app_id=app_id, | |||
| tenant_id=tenant_id, | |||
| app_orchestration_config_entity=app_orchestration_config_entity, | |||
| inputs=inputs, | |||
| query=query, | |||
| ) | |||
| def query_app_annotations_to_reply(self, app_record: App, | |||
| message: Message, | |||
| query: str, | |||
| user_id: str, | |||
| invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: | |||
| """ | |||
| Query app annotations to reply | |||
| :param app_record: app record | |||
| :param message: message | |||
| :param query: query | |||
| :param user_id: user id | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| annotation_reply_feature = AnnotationReplyFeature() | |||
| return annotation_reply_feature.query( | |||
| app_record=app_record, | |||
| message=message, | |||
| query=query, | |||
| user_id=user_id, | |||
| invoke_from=invoke_from | |||
| ) | |||
| def fill_in_inputs_from_external_data_tools(self, tenant_id: str, | |||
| app_id: str, | |||
| external_data_tools: list[ExternalDataVariableEntity], | |||
| inputs: dict, | |||
| query: str) -> dict: | |||
| """ | |||
| Fill in variable inputs from external data tools if they exist. | |||
| :param tenant_id: workspace id | |||
| :param app_id: app id | |||
| :param external_data_tools: external data tools configs | |||
| :param inputs: the inputs | |||
| :param query: the query | |||
| :return: the filled inputs | |||
| """ | |||
| external_data_fetch_feature = ExternalDataFetchFeature() | |||
| return external_data_fetch_feature.fetch( | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| external_data_tools=external_data_tools, | |||
| inputs=inputs, | |||
| query=query | |||
| ) | |||
| def retrieve_dataset_context(self, tenant_id: str, | |||
| app_record: App, | |||
| queue_manager: ApplicationQueueManager, | |||
| model_config: ModelConfigEntity, | |||
| dataset_config: DatasetEntity, | |||
| show_retrieve_source: bool, | |||
| message: Message, | |||
| inputs: dict, | |||
| query: str, | |||
| user_id: str, | |||
| invoke_from: InvokeFrom, | |||
| memory: Optional[TokenBufferMemory] = None) -> Optional[str]: | |||
| """ | |||
| Retrieve dataset context | |||
| :param tenant_id: tenant id | |||
| :param app_record: app record | |||
| :param queue_manager: queue manager | |||
| :param model_config: model config | |||
| :param dataset_config: dataset config | |||
| :param show_retrieve_source: show retrieve source | |||
| :param message: message | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :param user_id: user id | |||
| :param invoke_from: invoke from | |||
| :param memory: memory | |||
| :return: | |||
| """ | |||
| hit_callback = DatasetIndexToolCallbackHandler( | |||
| queue_manager, | |||
| app_record.id, | |||
| message.id, | |||
| user_id, | |||
| invoke_from | |||
| ) | |||
| if (app_record.mode == AppMode.COMPLETION.value and dataset_config | |||
| and dataset_config.retrieve_config.query_variable): | |||
| query = inputs.get(dataset_config.retrieve_config.query_variable, "") | |||
| dataset_retrieval = DatasetRetrievalFeature() | |||
| return dataset_retrieval.retrieve( | |||
| tenant_id=tenant_id, | |||
| model_config=model_config, | |||
| config=dataset_config, | |||
| query=query, | |||
| invoke_from=invoke_from, | |||
| show_retrieve_source=show_retrieve_source, | |||
| hit_callback=hit_callback, | |||
| memory=memory | |||
| ) | |||
| def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| prompt_messages: list[PromptMessage]) -> bool: | |||
| """ | |||
| Check hosting moderation | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param prompt_messages: prompt messages | |||
| :return: | |||
| """ | |||
| hosting_moderation_feature = HostingModerationFeature() | |||
| moderation_result = hosting_moderation_feature.check( | |||
| application_generate_entity=application_generate_entity, | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| if moderation_result: | |||
| self.direct_output( | |||
| queue_manager=queue_manager, | |||
| app_orchestration_config=application_generate_entity.app_orchestration_config_entity, | |||
| prompt_messages=prompt_messages, | |||
| text="I apologize for any confusion, " \ | |||
| "but I'm an AI assistant to be helpful, harmless, and honest.", | |||
| stream=application_generate_entity.stream | |||
| ) | |||
| return moderation_result | |||
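Reviewer note: a hypothetical driving sketch for the new runner (not part of this PR). The entity, queue manager, conversation and message are placeholders, assumed to be prepared by the caller the way ApplicationManager does further down in this changeset.

    from core.app_runner.basic_app_runner import BasicApplicationRunner

    # All four arguments are assumed to be built elsewhere (see
    # ApplicationManager._generate_worker later in this PR); they are placeholders here.
    runner = BasicApplicationRunner()
    runner.run(
        application_generate_entity=application_generate_entity,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
    )
    # run() only publishes queue events; GenerateTaskPipeline turns them into the API response.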
| @@ -0,0 +1,483 @@ | |||
| import json | |||
| import logging | |||
| import time | |||
| from typing import Union, Generator, cast, Optional | |||
| from pydantic import BaseModel | |||
| from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule | |||
| from core.entities.application_entities import ApplicationGenerateEntity | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \ | |||
| QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \ | |||
| AnnotationReplyEvent | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta | |||
| from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \ | |||
| TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage | |||
| from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from models.model import Message, Conversation, MessageAgentThought | |||
| from services.annotation_service import AppAnnotationService | |||
| logger = logging.getLogger(__name__) | |||
| class TaskState(BaseModel): | |||
| """ | |||
| TaskState entity | |||
| """ | |||
| llm_result: LLMResult | |||
| metadata: dict = {} | |||
| class GenerateTaskPipeline: | |||
| """ | |||
| GenerateTaskPipeline generates the streamed output and manages state for an application run. | |||
| """ | |||
| def __init__(self, application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| conversation: Conversation, | |||
| message: Message) -> None: | |||
| """ | |||
| Initialize GenerateTaskPipeline. | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| """ | |||
| self._application_generate_entity = application_generate_entity | |||
| self._queue_manager = queue_manager | |||
| self._conversation = conversation | |||
| self._message = message | |||
| self._task_state = TaskState( | |||
| llm_result=LLMResult( | |||
| model=self._application_generate_entity.app_orchestration_config_entity.model_config.model, | |||
| prompt_messages=[], | |||
| message=AssistantPromptMessage(content=""), | |||
| usage=LLMUsage.empty_usage() | |||
| ) | |||
| ) | |||
| self._start_at = time.perf_counter() | |||
| self._output_moderation_handler = self._init_output_moderation() | |||
| def process(self, stream: bool) -> Union[dict, Generator]: | |||
| """ | |||
| Process generate task pipeline. | |||
| :return: | |||
| """ | |||
| if stream: | |||
| return self._process_stream_response() | |||
| else: | |||
| return self._process_blocking_response() | |||
| def _process_blocking_response(self) -> dict: | |||
| """ | |||
| Process blocking response. | |||
| :return: | |||
| """ | |||
| for queue_message in self._queue_manager.listen(): | |||
| event = queue_message.event | |||
| if isinstance(event, QueueErrorEvent): | |||
| raise self._handle_error(event) | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._task_state.metadata['retriever_resources'] = event.retriever_resources | |||
| elif isinstance(event, AnnotationReplyEvent): | |||
| annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) | |||
| if annotation: | |||
| account = annotation.account | |||
| self._task_state.metadata['annotation_reply'] = { | |||
| 'id': annotation.id, | |||
| 'account': { | |||
| 'id': annotation.account_id, | |||
| 'name': account.name if account else 'Dify user' | |||
| } | |||
| } | |||
| self._task_state.llm_result.message.content = annotation.content | |||
| elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)): | |||
| if isinstance(event, QueueMessageEndEvent): | |||
| self._task_state.llm_result = event.llm_result | |||
| else: | |||
| model_config = self._application_generate_entity.app_orchestration_config_entity.model_config | |||
| model = model_config.model | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # calculate num tokens | |||
| prompt_tokens = 0 | |||
| if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: | |||
| prompt_tokens = model_type_instance.get_num_tokens( | |||
| model, | |||
| model_config.credentials, | |||
| self._task_state.llm_result.prompt_messages | |||
| ) | |||
| completion_tokens = 0 | |||
| if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: | |||
| completion_tokens = model_type_instance.get_num_tokens( | |||
| model, | |||
| model_config.credentials, | |||
| [self._task_state.llm_result.message] | |||
| ) | |||
| credentials = model_config.credentials | |||
| # transform usage | |||
| self._task_state.llm_result.usage = model_type_instance._calc_response_usage( | |||
| model, | |||
| credentials, | |||
| prompt_tokens, | |||
| completion_tokens | |||
| ) | |||
| # response moderation | |||
| if self._output_moderation_handler: | |||
| self._output_moderation_handler.stop_thread() | |||
| self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion( | |||
| completion=self._task_state.llm_result.message.content, | |||
| public_event=False | |||
| ) | |||
| # Save message | |||
| self._save_message(event.llm_result) | |||
| response = { | |||
| 'event': 'message', | |||
| 'task_id': self._application_generate_entity.task_id, | |||
| 'id': self._message.id, | |||
| 'mode': self._conversation.mode, | |||
| 'answer': event.llm_result.message.content, | |||
| 'metadata': {}, | |||
| 'created_at': int(self._message.created_at.timestamp()) | |||
| } | |||
| if self._conversation.mode == 'chat': | |||
| response['conversation_id'] = self._conversation.id | |||
| if self._task_state.metadata: | |||
| response['metadata'] = self._task_state.metadata | |||
| return response | |||
| else: | |||
| continue | |||
| def _process_stream_response(self) -> Generator: | |||
| """ | |||
| Process stream response. | |||
| :return: | |||
| """ | |||
| for message in self._queue_manager.listen(): | |||
| event = message.event | |||
| if isinstance(event, QueueErrorEvent): | |||
| raise self._handle_error(event) | |||
| elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)): | |||
| if isinstance(event, QueueMessageEndEvent): | |||
| self._task_state.llm_result = event.llm_result | |||
| else: | |||
| model_config = self._application_generate_entity.app_orchestration_config_entity.model_config | |||
| model = model_config.model | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # calculate num tokens | |||
| prompt_tokens = 0 | |||
| if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY: | |||
| prompt_tokens = model_type_instance.get_num_tokens( | |||
| model, | |||
| model_config.credentials, | |||
| self._task_state.llm_result.prompt_messages | |||
| ) | |||
| completion_tokens = 0 | |||
| if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL: | |||
| completion_tokens = model_type_instance.get_num_tokens( | |||
| model, | |||
| model_config.credentials, | |||
| [self._task_state.llm_result.message] | |||
| ) | |||
| credentials = model_config.credentials | |||
| # transform usage | |||
| self._task_state.llm_result.usage = model_type_instance._calc_response_usage( | |||
| model, | |||
| credentials, | |||
| prompt_tokens, | |||
| completion_tokens | |||
| ) | |||
| # response moderation | |||
| if self._output_moderation_handler: | |||
| self._output_moderation_handler.stop_thread() | |||
| self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion( | |||
| completion=self._task_state.llm_result.message.content, | |||
| public_event=False | |||
| ) | |||
| self._output_moderation_handler = None | |||
| replace_response = { | |||
| 'event': 'message_replace', | |||
| 'task_id': self._application_generate_entity.task_id, | |||
| 'message_id': self._message.id, | |||
| 'answer': self._task_state.llm_result.message.content, | |||
| 'created_at': int(self._message.created_at.timestamp()) | |||
| } | |||
| if self._conversation.mode == 'chat': | |||
| replace_response['conversation_id'] = self._conversation.id | |||
| yield self._yield_response(replace_response) | |||
| # Save message | |||
| self._save_message(self._task_state.llm_result) | |||
| response = { | |||
| 'event': 'message_end', | |||
| 'task_id': self._application_generate_entity.task_id, | |||
| 'id': self._message.id, | |||
| } | |||
| if self._conversation.mode == 'chat': | |||
| response['conversation_id'] = self._conversation.id | |||
| if self._task_state.metadata: | |||
| response['metadata'] = self._task_state.metadata | |||
| yield self._yield_response(response) | |||
| elif isinstance(event, QueueRetrieverResourcesEvent): | |||
| self._task_state.metadata['retriever_resources'] = event.retriever_resources | |||
| elif isinstance(event, AnnotationReplyEvent): | |||
| annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) | |||
| if annotation: | |||
| account = annotation.account | |||
| self._task_state.metadata['annotation_reply'] = { | |||
| 'id': annotation.id, | |||
| 'account': { | |||
| 'id': annotation.account_id, | |||
| 'name': account.name if account else 'Dify user' | |||
| } | |||
| } | |||
| self._task_state.llm_result.message.content = annotation.content | |||
| elif isinstance(event, QueueAgentThoughtEvent): | |||
| agent_thought = ( | |||
| db.session.query(MessageAgentThought) | |||
| .filter(MessageAgentThought.id == event.agent_thought_id) | |||
| .first() | |||
| ) | |||
| if agent_thought: | |||
| response = { | |||
| 'event': 'agent_thought', | |||
| 'id': agent_thought.id, | |||
| 'task_id': self._application_generate_entity.task_id, | |||
| 'message_id': self._message.id, | |||
| 'position': agent_thought.position, | |||
| 'thought': agent_thought.thought, | |||
| 'tool': agent_thought.tool, | |||
| 'tool_input': agent_thought.tool_input, | |||
| 'created_at': int(self._message.created_at.timestamp()) | |||
| } | |||
| if self._conversation.mode == 'chat': | |||
| response['conversation_id'] = self._conversation.id | |||
| yield self._yield_response(response) | |||
| elif isinstance(event, QueueMessageEvent): | |||
| chunk = event.chunk | |||
| delta_text = chunk.delta.message.content | |||
| if delta_text is None: | |||
| continue | |||
| if not self._task_state.llm_result.prompt_messages: | |||
| self._task_state.llm_result.prompt_messages = chunk.prompt_messages | |||
| if self._output_moderation_handler: | |||
| if self._output_moderation_handler.should_direct_output(): | |||
| # stop consuming new tokens once output moderation decides to direct the output | |||
| self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output() | |||
| self._queue_manager.publish_chunk_message(LLMResultChunk( | |||
| model=self._task_state.llm_result.model, | |||
| prompt_messages=self._task_state.llm_result.prompt_messages, | |||
| delta=LLMResultChunkDelta( | |||
| index=0, | |||
| message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) | |||
| ) | |||
| )) | |||
| self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION)) | |||
| continue | |||
| else: | |||
| self._output_moderation_handler.append_new_token(delta_text) | |||
| self._task_state.llm_result.message.content += delta_text | |||
| response = self._handle_chunk(delta_text) | |||
| yield self._yield_response(response) | |||
| elif isinstance(event, QueueMessageReplaceEvent): | |||
| response = { | |||
| 'event': 'message_replace', | |||
| 'task_id': self._application_generate_entity.task_id, | |||
| 'message_id': self._message.id, | |||
| 'answer': event.text, | |||
| 'created_at': int(self._message.created_at.timestamp()) | |||
| } | |||
| if self._conversation.mode == 'chat': | |||
| response['conversation_id'] = self._conversation.id | |||
| yield self._yield_response(response) | |||
| elif isinstance(event, QueuePingEvent): | |||
| yield "event: ping\n\n" | |||
| else: | |||
| continue | |||
| def _save_message(self, llm_result: LLMResult) -> None: | |||
| """ | |||
| Save message. | |||
| :param llm_result: llm result | |||
| :return: | |||
| """ | |||
| usage = llm_result.usage | |||
| self._message = db.session.query(Message).filter(Message.id == self._message.id).first() | |||
| self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) | |||
| self._message.message_tokens = usage.prompt_tokens | |||
| self._message.message_unit_price = usage.prompt_unit_price | |||
| self._message.message_price_unit = usage.prompt_price_unit | |||
| self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \ | |||
| if llm_result.message.content else '' | |||
| self._message.answer_tokens = usage.completion_tokens | |||
| self._message.answer_unit_price = usage.completion_unit_price | |||
| self._message.answer_price_unit = usage.completion_price_unit | |||
| self._message.provider_response_latency = time.perf_counter() - self._start_at | |||
| self._message.total_price = usage.total_price | |||
| db.session.commit() | |||
| message_was_created.send( | |||
| self._message, | |||
| application_generate_entity=self._application_generate_entity, | |||
| conversation=self._conversation, | |||
| is_first_message=self._application_generate_entity.conversation_id is None, | |||
| extras=self._application_generate_entity.extras | |||
| ) | |||
| def _handle_chunk(self, text: str) -> dict: | |||
| """ | |||
| Handle a message chunk. | |||
| :param text: text | |||
| :return: | |||
| """ | |||
| response = { | |||
| 'event': 'message', | |||
| 'id': self._message.id, | |||
| 'task_id': self._application_generate_entity.task_id, | |||
| 'message_id': self._message.id, | |||
| 'answer': text, | |||
| 'created_at': int(self._message.created_at.timestamp()) | |||
| } | |||
| if self._conversation.mode == 'chat': | |||
| response['conversation_id'] = self._conversation.id | |||
| return response | |||
| def _handle_error(self, event: QueueErrorEvent) -> Exception: | |||
| """ | |||
| Handle error event. | |||
| :param event: event | |||
| :return: | |||
| """ | |||
| logger.debug("error: %s", event.error) | |||
| e = event.error | |||
| if isinstance(e, InvokeAuthorizationError): | |||
| return InvokeAuthorizationError('Incorrect API key provided') | |||
| elif isinstance(e, InvokeError) or isinstance(e, ValueError): | |||
| return e | |||
| else: | |||
| return Exception(e.description if getattr(e, 'description', None) is not None else str(e)) | |||
| def _yield_response(self, response: dict) -> str: | |||
| """ | |||
| Yield response. | |||
| :param response: response | |||
| :return: | |||
| """ | |||
| return "data: " + json.dumps(response) + "\n\n" | |||
| def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]: | |||
| """ | |||
| Convert prompt messages into the prompt format used for saving. | |||
| :param prompt_messages: prompt messages | |||
| :return: | |||
| """ | |||
| prompts = [] | |||
| if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat': | |||
| for prompt_message in prompt_messages: | |||
| if prompt_message.role == PromptMessageRole.USER: | |||
| role = 'user' | |||
| elif prompt_message.role == PromptMessageRole.ASSISTANT: | |||
| role = 'assistant' | |||
| elif prompt_message.role == PromptMessageRole.SYSTEM: | |||
| role = 'system' | |||
| else: | |||
| continue | |||
| text = '' | |||
| files = [] | |||
| if isinstance(prompt_message.content, list): | |||
| for content in prompt_message.content: | |||
| if content.type == PromptMessageContentType.TEXT: | |||
| content = cast(TextPromptMessageContent, content) | |||
| text += content.data | |||
| else: | |||
| content = cast(ImagePromptMessageContent, content) | |||
| files.append({ | |||
| "type": 'image', | |||
| "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:], | |||
| "detail": content.detail.value | |||
| }) | |||
| else: | |||
| text = prompt_message.content | |||
| prompts.append({ | |||
| "role": role, | |||
| "text": text, | |||
| "files": files | |||
| }) | |||
| else: | |||
| prompts.append({ | |||
| "role": 'user', | |||
| "text": prompt_messages[0].content | |||
| }) | |||
| return prompts | |||
| def _init_output_moderation(self) -> Optional[OutputModerationHandler]: | |||
| """ | |||
| Init output moderation. | |||
| :return: | |||
| """ | |||
| app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity | |||
| sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance | |||
| if sensitive_word_avoidance: | |||
| return OutputModerationHandler( | |||
| tenant_id=self._application_generate_entity.tenant_id, | |||
| app_id=self._application_generate_entity.app_id, | |||
| rule=ModerationRule( | |||
| type=sensitive_word_avoidance.type, | |||
| config=sensitive_word_avoidance.config | |||
| ), | |||
| on_message_replace_func=self._queue_manager.publish_message_replace | |||
| ) | |||
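Reviewer note: a small self-contained illustration (not part of this PR) of the server-sent-event framing produced by _process_stream_response / _yield_response, parsed from the client side; the sample frames below are made up.

    import json

    def parse_sse(frames):
        # Each data frame is "data: <json>\n\n"; "event: ping\n\n" keep-alives carry no payload.
        for frame in frames:
            if frame.startswith("data: "):
                yield json.loads(frame[len("data: "):].strip())

    frames = [
        'data: {"event": "message", "answer": "Hel"}\n\n',
        "event: ping\n\n",
        'data: {"event": "message_end", "id": "m-1"}\n\n',
    ]
    assert [e["event"] for e in parse_sse(frames)] == ["message", "message_end"]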
| @@ -0,0 +1,138 @@ | |||
| import logging | |||
| import threading | |||
| import time | |||
| from typing import Any, Optional, Dict | |||
| from flask import current_app, Flask | |||
| from pydantic import BaseModel | |||
| from core.moderation.base import ModerationAction, ModerationOutputsResult | |||
| from core.moderation.factory import ModerationFactory | |||
| logger = logging.getLogger(__name__) | |||
| class ModerationRule(BaseModel): | |||
| type: str | |||
| config: Dict[str, Any] | |||
| class OutputModerationHandler(BaseModel): | |||
| DEFAULT_BUFFER_SIZE: int = 300 | |||
| tenant_id: str | |||
| app_id: str | |||
| rule: ModerationRule | |||
| on_message_replace_func: Any | |||
| thread: Optional[threading.Thread] = None | |||
| thread_running: bool = True | |||
| buffer: str = '' | |||
| is_final_chunk: bool = False | |||
| final_output: Optional[str] = None | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def should_direct_output(self): | |||
| return self.final_output is not None | |||
| def get_final_output(self): | |||
| return self.final_output | |||
| def append_new_token(self, token: str): | |||
| self.buffer += token | |||
| if not self.thread: | |||
| self.thread = self.start_thread() | |||
| def moderation_completion(self, completion: str, public_event: bool = False) -> str: | |||
| self.buffer = completion | |||
| self.is_final_chunk = True | |||
| result = self.moderation( | |||
| tenant_id=self.tenant_id, | |||
| app_id=self.app_id, | |||
| moderation_buffer=completion | |||
| ) | |||
| if not result or not result.flagged: | |||
| return completion | |||
| if result.action == ModerationAction.DIRECT_OUTPUT: | |||
| final_output = result.preset_response | |||
| else: | |||
| final_output = result.text | |||
| if public_event: | |||
| self.on_message_replace_func(final_output) | |||
| return final_output | |||
| def start_thread(self) -> threading.Thread: | |||
| buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE)) | |||
| thread = threading.Thread(target=self.worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE | |||
| }) | |||
| thread.start() | |||
| return thread | |||
| def stop_thread(self): | |||
| if self.thread and self.thread.is_alive(): | |||
| self.thread_running = False | |||
| def worker(self, flask_app: Flask, buffer_size: int): | |||
| with flask_app.app_context(): | |||
| current_length = 0 | |||
| while self.thread_running: | |||
| moderation_buffer = self.buffer | |||
| buffer_length = len(moderation_buffer) | |||
| if not self.is_final_chunk: | |||
| chunk_length = buffer_length - current_length | |||
| if 0 <= chunk_length < buffer_size: | |||
| time.sleep(1) | |||
| continue | |||
| current_length = buffer_length | |||
| result = self.moderation( | |||
| tenant_id=self.tenant_id, | |||
| app_id=self.app_id, | |||
| moderation_buffer=moderation_buffer | |||
| ) | |||
| if not result or not result.flagged: | |||
| continue | |||
| if result.action == ModerationAction.DIRECT_OUTPUT: | |||
| final_output = result.preset_response | |||
| self.final_output = final_output | |||
| else: | |||
| final_output = result.text + self.buffer[len(moderation_buffer):] | |||
| # trigger replace event | |||
| if self.thread_running: | |||
| self.on_message_replace_func(final_output) | |||
| if result.action == ModerationAction.DIRECT_OUTPUT: | |||
| break | |||
| def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: | |||
| try: | |||
| moderation_factory = ModerationFactory( | |||
| name=self.rule.type, | |||
| app_id=app_id, | |||
| tenant_id=tenant_id, | |||
| config=self.rule.config | |||
| ) | |||
| result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) | |||
| return result | |||
| except Exception as e: | |||
| logger.error("Moderation Output error: %s", e) | |||
| return None | |||
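Reviewer note: a toy, self-contained restatement (not part of this PR) of the buffering rule in OutputModerationHandler.worker: the accumulated text is only re-checked once at least buffer_size new characters have arrived, except that the final chunk is always checked.

    def should_moderate(buffer: str, checked_length: int, buffer_size: int, is_final_chunk: bool) -> bool:
        # Mirrors the `0 <= chunk_length < buffer_size` guard above: skip moderation while
        # the un-checked tail of the buffer is still shorter than buffer_size.
        if is_final_chunk:
            return True
        return len(buffer) - checked_length >= buffer_size

    assert should_moderate("a" * 100, 0, 300, False) is False  # keep accumulating
    assert should_moderate("a" * 400, 0, 300, False) is True   # enough new text to check
    assert should_moderate("a" * 100, 0, 300, True) is True    # final chunk is always checked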
| @@ -0,0 +1,655 @@ | |||
| import json | |||
| import logging | |||
| import threading | |||
| import uuid | |||
| from typing import cast, Optional, Any, Union, Generator, Tuple | |||
| from flask import Flask, current_app | |||
| from pydantic import ValidationError | |||
| from core.app_runner.agent_app_runner import AgentApplicationRunner | |||
| from core.app_runner.basic_app_runner import BasicApplicationRunner | |||
| from core.app_runner.generate_task_pipeline import GenerateTaskPipeline | |||
| from core.entities.application_entities import ApplicationGenerateEntity, AppOrchestrationConfigEntity, \ | |||
| ModelConfigEntity, PromptTemplateEntity, AdvancedChatPromptTemplateEntity, \ | |||
| AdvancedCompletionPromptTemplateEntity, ExternalDataVariableEntity, DatasetEntity, DatasetRetrieveConfigEntity, \ | |||
| AgentEntity, AgentToolEntity, FileUploadEntity, SensitiveWordAvoidanceEntity, InvokeFrom | |||
| from core.entities.model_entities import ModelStatus | |||
| from core.file.file_obj import FileObj | |||
| from core.errors.error import QuotaExceededError, ProviderTokenNotInitError, ModelCurrentlyNotSupportError | |||
| from core.model_runtime.entities.message_entities import PromptMessageRole | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from core.provider_manager import ProviderManager | |||
| from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.model import EndUser, Conversation, Message, MessageFile, App | |||
| logger = logging.getLogger(__name__) | |||
| class ApplicationManager: | |||
| """ | |||
| This class is responsible for managing the application generation flow | |||
| """ | |||
| def generate(self, tenant_id: str, | |||
| app_id: str, | |||
| app_model_config_id: str, | |||
| app_model_config_dict: dict, | |||
| app_model_config_override: bool, | |||
| user: Union[Account, EndUser], | |||
| invoke_from: InvokeFrom, | |||
| inputs: dict[str, str], | |||
| query: Optional[str] = None, | |||
| files: Optional[list[FileObj]] = None, | |||
| conversation: Optional[Conversation] = None, | |||
| stream: bool = False, | |||
| extras: Optional[dict[str, Any]] = None) \ | |||
| -> Union[dict, Generator]: | |||
| """ | |||
| Generate App response. | |||
| :param tenant_id: workspace ID | |||
| :param app_id: app ID | |||
| :param app_model_config_id: app model config id | |||
| :param app_model_config_dict: app model config dict | |||
| :param app_model_config_override: app model config override | |||
| :param user: account or end user | |||
| :param invoke_from: invoke from source | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :param files: file obj list | |||
| :param conversation: conversation | |||
| :param stream: is stream | |||
| :param extras: extras | |||
| """ | |||
| # init task id | |||
| task_id = str(uuid.uuid4()) | |||
| # init application generate entity | |||
| application_generate_entity = ApplicationGenerateEntity( | |||
| task_id=task_id, | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| app_model_config_id=app_model_config_id, | |||
| app_model_config_dict=app_model_config_dict, | |||
| app_orchestration_config_entity=self._convert_from_app_model_config_dict( | |||
| tenant_id=tenant_id, | |||
| app_model_config_dict=app_model_config_dict | |||
| ), | |||
| app_model_config_override=app_model_config_override, | |||
| conversation_id=conversation.id if conversation else None, | |||
| inputs=conversation.inputs if conversation else inputs, | |||
| query=query.replace('\x00', '') if query else None, | |||
| files=files if files else [], | |||
| user_id=user.id, | |||
| stream=stream, | |||
| invoke_from=invoke_from, | |||
| extras=extras | |||
| ) | |||
| # init generate records | |||
| ( | |||
| conversation, | |||
| message | |||
| ) = self._init_generate_records(application_generate_entity) | |||
| # init queue manager | |||
| queue_manager = ApplicationQueueManager( | |||
| task_id=application_generate_entity.task_id, | |||
| user_id=application_generate_entity.user_id, | |||
| invoke_from=application_generate_entity.invoke_from, | |||
| conversation_id=conversation.id, | |||
| app_mode=conversation.mode, | |||
| message_id=message.id | |||
| ) | |||
| # new thread | |||
| worker_thread = threading.Thread(target=self._generate_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'application_generate_entity': application_generate_entity, | |||
| 'queue_manager': queue_manager, | |||
| 'conversation_id': conversation.id, | |||
| 'message_id': message.id, | |||
| }) | |||
| worker_thread.start() | |||
| # return response or stream generator | |||
| return self._handle_response( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message, | |||
| stream=stream | |||
| ) | |||
| def _generate_worker(self, flask_app: Flask, | |||
| application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| conversation_id: str, | |||
| message_id: str) -> None: | |||
| """ | |||
| Generate worker in a new thread. | |||
| :param flask_app: Flask app | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param conversation_id: conversation ID | |||
| :param message_id: message ID | |||
| :return: | |||
| """ | |||
| with flask_app.app_context(): | |||
| try: | |||
| # get conversation and message | |||
| conversation = self._get_conversation(conversation_id) | |||
| message = self._get_message(message_id) | |||
| if application_generate_entity.app_orchestration_config_entity.agent: | |||
| # agent app | |||
| runner = AgentApplicationRunner() | |||
| runner.run( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message | |||
| ) | |||
| else: | |||
| # basic app | |||
| runner = BasicApplicationRunner() | |||
| runner.run( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message | |||
| ) | |||
| except ConversationTaskStoppedException: | |||
| pass | |||
| except InvokeAuthorizationError: | |||
| queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided')) | |||
| except ValidationError as e: | |||
| logger.exception("Validation Error when generating") | |||
| queue_manager.publish_error(e) | |||
| except (ValueError, InvokeError) as e: | |||
| queue_manager.publish_error(e) | |||
| except Exception as e: | |||
| logger.exception("Unknown Error when generating") | |||
| queue_manager.publish_error(e) | |||
| finally: | |||
| db.session.remove() | |||
| def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| conversation: Conversation, | |||
| message: Message, | |||
| stream: bool = False) -> Union[dict, Generator]: | |||
| """ | |||
| Handle response. | |||
| :param application_generate_entity: application generate entity | |||
| :param queue_manager: queue manager | |||
| :param conversation: conversation | |||
| :param message: message | |||
| :param stream: is stream | |||
| :return: | |||
| """ | |||
| # init generate task pipeline | |||
| generate_task_pipeline = GenerateTaskPipeline( | |||
| application_generate_entity=application_generate_entity, | |||
| queue_manager=queue_manager, | |||
| conversation=conversation, | |||
| message=message | |||
| ) | |||
| try: | |||
| return generate_task_pipeline.process(stream=stream) | |||
| except ValueError as e: | |||
| if e.args[0] == "I/O operation on closed file.": # ignore this error | |||
| raise ConversationTaskStoppedException() | |||
| else: | |||
| logger.exception(e) | |||
| raise e | |||
| finally: | |||
| db.session.remove() | |||
| def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ | |||
| -> AppOrchestrationConfigEntity: | |||
| """ | |||
| Convert app model config dict to entity. | |||
| :param tenant_id: tenant ID | |||
| :param app_model_config_dict: app model config dict | |||
| :raises ProviderTokenNotInitError: provider token not init error | |||
| :return: app orchestration config entity | |||
| """ | |||
| properties = {} | |||
| copy_app_model_config_dict = app_model_config_dict.copy() | |||
| provider_manager = ProviderManager() | |||
| provider_model_bundle = provider_manager.get_provider_model_bundle( | |||
| tenant_id=tenant_id, | |||
| provider=copy_app_model_config_dict['model']['provider'], | |||
| model_type=ModelType.LLM | |||
| ) | |||
| provider_name = provider_model_bundle.configuration.provider.provider | |||
| model_name = copy_app_model_config_dict['model']['name'] | |||
| model_type_instance = provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # check model credentials | |||
| model_credentials = provider_model_bundle.configuration.get_current_credentials( | |||
| model_type=ModelType.LLM, | |||
| model=copy_app_model_config_dict['model']['name'] | |||
| ) | |||
| if model_credentials is None: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| # check model | |||
| provider_model = provider_model_bundle.configuration.get_provider_model( | |||
| model=copy_app_model_config_dict['model']['name'], | |||
| model_type=ModelType.LLM | |||
| ) | |||
| if provider_model is None: | |||
| model_name = copy_app_model_config_dict['model']['name'] | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| if provider_model.status == ModelStatus.NO_CONFIGURE: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") | |||
| elif provider_model.status == ModelStatus.NO_PERMISSION: | |||
| raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} is currently not supported.") | |||
| elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: | |||
| raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") | |||
| # model config | |||
| completion_params = copy_app_model_config_dict['model'].get('completion_params') | |||
| stop = [] | |||
| if 'stop' in completion_params: | |||
| stop = completion_params['stop'] | |||
| del completion_params['stop'] | |||
| # get model mode | |||
| model_mode = copy_app_model_config_dict['model'].get('mode') | |||
| if not model_mode: | |||
| mode_enum = model_type_instance.get_model_mode( | |||
| model=copy_app_model_config_dict['model']['name'], | |||
| credentials=model_credentials | |||
| ) | |||
| model_mode = mode_enum.value | |||
| model_schema = model_type_instance.get_model_schema( | |||
| copy_app_model_config_dict['model']['name'], | |||
| model_credentials | |||
| ) | |||
| if not model_schema: | |||
| raise ValueError(f"Model {model_name} not exist.") | |||
| properties['model_config'] = ModelConfigEntity( | |||
| provider=copy_app_model_config_dict['model']['provider'], | |||
| model=copy_app_model_config_dict['model']['name'], | |||
| model_schema=model_schema, | |||
| mode=model_mode, | |||
| provider_model_bundle=provider_model_bundle, | |||
| credentials=model_credentials, | |||
| parameters=completion_params, | |||
| stop=stop, | |||
| ) | |||
| # prompt template | |||
| prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type']) | |||
| if prompt_type == PromptTemplateEntity.PromptType.SIMPLE: | |||
| simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "") | |||
| properties['prompt_template'] = PromptTemplateEntity( | |||
| prompt_type=prompt_type, | |||
| simple_prompt_template=simple_prompt_template | |||
| ) | |||
| else: | |||
| advanced_chat_prompt_template = None | |||
| chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {}) | |||
| if chat_prompt_config: | |||
| chat_prompt_messages = [] | |||
| for message in chat_prompt_config.get("prompt", []): | |||
| chat_prompt_messages.append({ | |||
| "text": message["text"], | |||
| "role": PromptMessageRole.value_of(message["role"]) | |||
| }) | |||
| advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity( | |||
| messages=chat_prompt_messages | |||
| ) | |||
| advanced_completion_prompt_template = None | |||
| completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {}) | |||
| if completion_prompt_config: | |||
| completion_prompt_template_params = { | |||
| 'prompt': completion_prompt_config['prompt']['text'], | |||
| } | |||
| if 'conversation_histories_role' in completion_prompt_config: | |||
| completion_prompt_template_params['role_prefix'] = { | |||
| 'user': completion_prompt_config['conversation_histories_role']['user_prefix'], | |||
| 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix'] | |||
| } | |||
| advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( | |||
| **completion_prompt_template_params | |||
| ) | |||
| properties['prompt_template'] = PromptTemplateEntity( | |||
| prompt_type=prompt_type, | |||
| advanced_chat_prompt_template=advanced_chat_prompt_template, | |||
| advanced_completion_prompt_template=advanced_completion_prompt_template | |||
| ) | |||
| # external data variables | |||
| properties['external_data_variables'] = [] | |||
| external_data_tools = copy_app_model_config_dict.get('external_data_tools', []) | |||
| for external_data_tool in external_data_tools: | |||
| if 'enabled' not in external_data_tool or not external_data_tool['enabled']: | |||
| continue | |||
| properties['external_data_variables'].append( | |||
| ExternalDataVariableEntity( | |||
| variable=external_data_tool['variable'], | |||
| type=external_data_tool['type'], | |||
| config=external_data_tool['config'] | |||
| ) | |||
| ) | |||
| # show retrieve source | |||
| show_retrieve_source = False | |||
| retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource') | |||
| if retriever_resource_dict: | |||
| if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']: | |||
| show_retrieve_source = True | |||
| properties['show_retrieve_source'] = show_retrieve_source | |||
| if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \ | |||
| and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][ | |||
| 'enabled']: | |||
| agent_dict = copy_app_model_config_dict.get('agent_mode') | |||
| if agent_dict['strategy'] in ['router', 'react_router']: | |||
| dataset_ids = [] | |||
| for tool in agent_dict.get('tools', []): | |||
| key = list(tool.keys())[0] | |||
| if key != 'dataset': | |||
| continue | |||
| tool_item = tool[key] | |||
| if "enabled" not in tool_item or not tool_item["enabled"]: | |||
| continue | |||
| dataset_id = tool_item['id'] | |||
| dataset_ids.append(dataset_id) | |||
| dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'}) | |||
| query_variable = copy_app_model_config_dict.get('dataset_query_variable') | |||
| if dataset_configs['retrieval_model'] == 'single': | |||
| properties['dataset'] = DatasetEntity( | |||
| dataset_ids=dataset_ids, | |||
| retrieve_config=DatasetRetrieveConfigEntity( | |||
| query_variable=query_variable, | |||
| retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( | |||
| dataset_configs['retrieval_model'] | |||
| ), | |||
| single_strategy=agent_dict['strategy'] | |||
| ) | |||
| ) | |||
| else: | |||
| properties['dataset'] = DatasetEntity( | |||
| dataset_ids=dataset_ids, | |||
| retrieve_config=DatasetRetrieveConfigEntity( | |||
| query_variable=query_variable, | |||
| retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of( | |||
| dataset_configs['retrieval_model'] | |||
| ), | |||
| top_k=dataset_configs.get('top_k'), | |||
| score_threshold=dataset_configs.get('score_threshold'), | |||
| reranking_model=dataset_configs.get('reranking_model') | |||
| ) | |||
| ) | |||
| else: | |||
| if agent_dict['strategy'] == 'react': | |||
| strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT | |||
| else: | |||
| strategy = AgentEntity.Strategy.FUNCTION_CALLING | |||
| agent_tools = [] | |||
| for tool in agent_dict.get('tools', []): | |||
| key = list(tool.keys())[0] | |||
| tool_item = tool[key] | |||
| agent_tool_properties = { | |||
| "tool_id": key | |||
| } | |||
| if "enabled" not in tool_item or not tool_item["enabled"]: | |||
| continue | |||
| agent_tool_properties["config"] = tool_item | |||
| agent_tools.append(AgentToolEntity(**agent_tool_properties)) | |||
| properties['agent'] = AgentEntity( | |||
| provider=properties['model_config'].provider, | |||
| model=properties['model_config'].model, | |||
| strategy=strategy, | |||
| tools=agent_tools | |||
| ) | |||
| # file upload | |||
| file_upload_dict = copy_app_model_config_dict.get('file_upload') | |||
| if file_upload_dict: | |||
| if 'image' in file_upload_dict and file_upload_dict['image']: | |||
| if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']: | |||
| properties['file_upload'] = FileUploadEntity( | |||
| image_config={ | |||
| 'number_limits': file_upload_dict['image']['number_limits'], | |||
| 'detail': file_upload_dict['image']['detail'], | |||
| 'transfer_methods': file_upload_dict['image']['transfer_methods'] | |||
| } | |||
| ) | |||
| # opening statement | |||
| properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement') | |||
| # suggested questions after answer | |||
| suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer') | |||
| if suggested_questions_after_answer_dict: | |||
| if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']: | |||
| properties['suggested_questions_after_answer'] = True | |||
| # more like this | |||
| more_like_this_dict = copy_app_model_config_dict.get('more_like_this') | |||
| if more_like_this_dict: | |||
| if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']: | |||
| properties['more_like_this'] = True | |||
| # speech to text | |||
| speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text') | |||
| if speech_to_text_dict: | |||
| if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']: | |||
| properties['speech_to_text'] = True | |||
| # sensitive word avoidance | |||
| sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance') | |||
| if sensitive_word_avoidance_dict: | |||
| if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']: | |||
| properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity( | |||
| type=sensitive_word_avoidance_dict.get('type'), | |||
| config=sensitive_word_avoidance_dict.get('config'), | |||
| ) | |||
| return AppOrchestrationConfigEntity(**properties) | |||
| def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ | |||
| -> Tuple[Conversation, Message]: | |||
| """ | |||
| Initialize generate records | |||
| :param application_generate_entity: application generate entity | |||
| :return: | |||
| """ | |||
| app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity | |||
| model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=app_orchestration_config_entity.model_config.model, | |||
| credentials=app_orchestration_config_entity.model_config.credentials | |||
| ) | |||
| app_record = (db.session.query(App) | |||
| .filter(App.id == application_generate_entity.app_id).first()) | |||
| app_mode = app_record.mode | |||
| # get from source | |||
| end_user_id = None | |||
| account_id = None | |||
| if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: | |||
| from_source = 'api' | |||
| end_user_id = application_generate_entity.user_id | |||
| else: | |||
| from_source = 'console' | |||
| account_id = application_generate_entity.user_id | |||
| override_model_configs = None | |||
| if application_generate_entity.app_model_config_override: | |||
| override_model_configs = application_generate_entity.app_model_config_dict | |||
| introduction = '' | |||
| if app_mode == 'chat': | |||
| # get conversation introduction | |||
| introduction = self._get_conversation_introduction(application_generate_entity) | |||
| if not application_generate_entity.conversation_id: | |||
| conversation = Conversation( | |||
| app_id=app_record.id, | |||
| app_model_config_id=application_generate_entity.app_model_config_id, | |||
| model_provider=app_orchestration_config_entity.model_config.provider, | |||
| model_id=app_orchestration_config_entity.model_config.model, | |||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | |||
| mode=app_mode, | |||
| name='New conversation', | |||
| inputs=application_generate_entity.inputs, | |||
| introduction=introduction, | |||
| system_instruction="", | |||
| system_instruction_tokens=0, | |||
| status='normal', | |||
| from_source=from_source, | |||
| from_end_user_id=end_user_id, | |||
| from_account_id=account_id, | |||
| ) | |||
| db.session.add(conversation) | |||
| db.session.commit() | |||
| else: | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter( | |||
| Conversation.id == application_generate_entity.conversation_id, | |||
| Conversation.app_id == app_record.id | |||
| ).first() | |||
| ) | |||
| currency = model_schema.pricing.currency if model_schema.pricing else 'USD' | |||
| message = Message( | |||
| app_id=app_record.id, | |||
| model_provider=app_orchestration_config_entity.model_config.provider, | |||
| model_id=app_orchestration_config_entity.model_config.model, | |||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | |||
| conversation_id=conversation.id, | |||
| inputs=application_generate_entity.inputs, | |||
| query=application_generate_entity.query or "", | |||
| message="", | |||
| message_tokens=0, | |||
| message_unit_price=0, | |||
| message_price_unit=0, | |||
| answer="", | |||
| answer_tokens=0, | |||
| answer_unit_price=0, | |||
| answer_price_unit=0, | |||
| provider_response_latency=0, | |||
| total_price=0, | |||
| currency=currency, | |||
| from_source=from_source, | |||
| from_end_user_id=end_user_id, | |||
| from_account_id=account_id, | |||
| agent_based=app_orchestration_config_entity.agent is not None | |||
| ) | |||
| db.session.add(message) | |||
| db.session.commit() | |||
| for file in application_generate_entity.files: | |||
| message_file = MessageFile( | |||
| message_id=message.id, | |||
| type=file.type.value, | |||
| transfer_method=file.transfer_method.value, | |||
| url=file.url, | |||
| upload_file_id=file.upload_file_id, | |||
| created_by_role=('account' if account_id else 'end_user'), | |||
| created_by=account_id or end_user_id, | |||
| ) | |||
| db.session.add(message_file) | |||
| db.session.commit() | |||
| return conversation, message | |||
| def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str: | |||
| """ | |||
| Get conversation introduction | |||
| :param application_generate_entity: application generate entity | |||
| :return: conversation introduction | |||
| """ | |||
| app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity | |||
| introduction = app_orchestration_config_entity.opening_statement | |||
| if introduction: | |||
| try: | |||
| inputs = application_generate_entity.inputs | |||
| prompt_template = PromptTemplateParser(template=introduction) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| introduction = prompt_template.format(prompt_inputs) | |||
| except KeyError: | |||
| pass | |||
| return introduction | |||
| def _get_conversation(self, conversation_id: str) -> Conversation: | |||
| """ | |||
| Get conversation by conversation id | |||
| :param conversation_id: conversation id | |||
| :return: conversation | |||
| """ | |||
| conversation = ( | |||
| db.session.query(Conversation) | |||
| .filter(Conversation.id == conversation_id) | |||
| .first() | |||
| ) | |||
| return conversation | |||
| def _get_message(self, message_id: str) -> Message: | |||
| """ | |||
| Get message by message id | |||
| :param message_id: message id | |||
| :return: message | |||
| """ | |||
| message = ( | |||
| db.session.query(Message) | |||
| .filter(Message.id == message_id) | |||
| .first() | |||
| ) | |||
| return message | |||
| @@ -0,0 +1,228 @@ | |||
| import queue | |||
| import time | |||
| from typing import Generator, Any | |||
| from sqlalchemy.orm import DeclarativeMeta | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.entities.queue_entities import QueueStopEvent, AppQueueEvent, QueuePingEvent, QueueErrorEvent, \ | |||
| QueueAgentThoughtEvent, QueueMessageEndEvent, QueueRetrieverResourcesEvent, QueueMessageReplaceEvent, \ | |||
| QueueMessageEvent, QueueMessage, AnnotationReplyEvent | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from extensions.ext_redis import redis_client | |||
| from models.model import MessageAgentThought | |||
| class ApplicationQueueManager: | |||
| def __init__(self, task_id: str, | |||
| user_id: str, | |||
| invoke_from: InvokeFrom, | |||
| conversation_id: str, | |||
| app_mode: str, | |||
| message_id: str) -> None: | |||
| if not user_id: | |||
| raise ValueError("user is required") | |||
| self._task_id = task_id | |||
| self._user_id = user_id | |||
| self._invoke_from = invoke_from | |||
| self._conversation_id = str(conversation_id) | |||
| self._app_mode = app_mode | |||
| self._message_id = str(message_id) | |||
| user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' | |||
| redis_client.setex(ApplicationQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}") | |||
| q = queue.Queue() | |||
| self._q = q | |||
| def listen(self) -> Generator: | |||
| """ | |||
| Listen to queue | |||
| :return: | |||
| """ | |||
| # stop listening after at most 10 minutes | |||
| listen_timeout = 600 | |||
| start_time = time.time() | |||
| last_ping_time = 0 | |||
| while True: | |||
| try: | |||
| message = self._q.get(timeout=1) | |||
| if message is None: | |||
| break | |||
| yield message | |||
| except queue.Empty: | |||
| continue | |||
| finally: | |||
| elapsed_time = time.time() - start_time | |||
| if elapsed_time >= listen_timeout or self._is_stopped(): | |||
| # publish two messages so the client is sure to receive the stop signal | |||
| # and stops listening after the stop signal has been processed | |||
| self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) | |||
| self.stop_listen() | |||
| if elapsed_time // 10 > last_ping_time: | |||
| self.publish(QueuePingEvent()) | |||
| last_ping_time = elapsed_time // 10 | |||
| def stop_listen(self) -> None: | |||
| """ | |||
| Stop listening to the queue | |||
| :return: | |||
| """ | |||
| self._q.put(None) | |||
| def publish_chunk_message(self, chunk: LLMResultChunk) -> None: | |||
| """ | |||
| Publish chunk message to channel | |||
| :param chunk: chunk | |||
| :return: | |||
| """ | |||
| self.publish(QueueMessageEvent( | |||
| chunk=chunk | |||
| )) | |||
| def publish_message_replace(self, text: str) -> None: | |||
| """ | |||
| Publish message replace | |||
| :param text: text | |||
| :return: | |||
| """ | |||
| self.publish(QueueMessageReplaceEvent( | |||
| text=text | |||
| )) | |||
| def publish_retriever_resources(self, retriever_resources: list[dict]) -> None: | |||
| """ | |||
| Publish retriever resources | |||
| :return: | |||
| """ | |||
| self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources)) | |||
| def publish_annotation_reply(self, message_annotation_id: str) -> None: | |||
| """ | |||
| Publish annotation reply | |||
| :param message_annotation_id: message annotation id | |||
| :return: | |||
| """ | |||
| self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id)) | |||
| def publish_message_end(self, llm_result: LLMResult) -> None: | |||
| """ | |||
| Publish message end | |||
| :param llm_result: llm result | |||
| :return: | |||
| """ | |||
| self.publish(QueueMessageEndEvent(llm_result=llm_result)) | |||
| self.stop_listen() | |||
| def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None: | |||
| """ | |||
| Publish agent thought | |||
| :param message_agent_thought: message agent thought | |||
| :return: | |||
| """ | |||
| self.publish(QueueAgentThoughtEvent( | |||
| agent_thought_id=message_agent_thought.id | |||
| )) | |||
| def publish_error(self, e) -> None: | |||
| """ | |||
| Publish error | |||
| :param e: error | |||
| :return: | |||
| """ | |||
| self.publish(QueueErrorEvent( | |||
| error=e | |||
| )) | |||
| self.stop_listen() | |||
| def publish(self, event: AppQueueEvent) -> None: | |||
| """ | |||
| Publish event to queue | |||
| :param event: | |||
| :return: | |||
| """ | |||
| self._check_for_sqlalchemy_models(event.dict()) | |||
| message = QueueMessage( | |||
| task_id=self._task_id, | |||
| message_id=self._message_id, | |||
| conversation_id=self._conversation_id, | |||
| app_mode=self._app_mode, | |||
| event=event | |||
| ) | |||
| self._q.put(message) | |||
| if isinstance(event, QueueStopEvent): | |||
| self.stop_listen() | |||
| @classmethod | |||
| def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: | |||
| """ | |||
| Set task stop flag | |||
| :return: | |||
| """ | |||
| result = redis_client.get(cls._generate_task_belong_cache_key(task_id)) | |||
| if result is None: | |||
| return | |||
| user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user' | |||
| if result != f"{user_prefix}-{user_id}": | |||
| return | |||
| stopped_cache_key = cls._generate_stopped_cache_key(task_id) | |||
| redis_client.setex(stopped_cache_key, 600, 1) | |||
| def _is_stopped(self) -> bool: | |||
| """ | |||
| Check if task is stopped | |||
| :return: | |||
| """ | |||
| stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) | |||
| result = redis_client.get(stopped_cache_key) | |||
| if result is not None: | |||
| redis_client.delete(stopped_cache_key) | |||
| return True | |||
| return False | |||
| @classmethod | |||
| def _generate_task_belong_cache_key(cls, task_id: str) -> str: | |||
| """ | |||
| Generate the cache key that records which user a task belongs to | |||
| :param task_id: task id | |||
| :return: | |||
| """ | |||
| return f"generate_task_belong:{task_id}" | |||
| @classmethod | |||
| def _generate_stopped_cache_key(cls, task_id: str) -> str: | |||
| """ | |||
| Generate stopped cache key | |||
| :param task_id: task id | |||
| :return: | |||
| """ | |||
| return f"generate_task_stopped:{task_id}" | |||
| def _check_for_sqlalchemy_models(self, data: Any): | |||
| # recursively inspect the dicts and lists produced from the entity | |||
| if isinstance(data, dict): | |||
| for key, value in data.items(): | |||
| self._check_for_sqlalchemy_models(value) | |||
| elif isinstance(data, list): | |||
| for item in data: | |||
| self._check_for_sqlalchemy_models(item) | |||
| else: | |||
| if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'): | |||
| raise TypeError("Critical Error: Passing SQLAlchemy Model instances, " | |||
| "which can cause thread safety issues, is not allowed.") | |||
| class ConversationTaskStoppedException(Exception): | |||
| pass | |||
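For orientation, a minimal hypothetical sketch of how a worker thread and the request thread might use this queue manager; the names `llm_stream`, `llm_result`, `conversation` and `message` are placeholders, not objects taken from this pull request:

queue_manager = ApplicationQueueManager(
    task_id=task_id,
    user_id=user.id,
    invoke_from=InvokeFrom.DEBUGGER,
    conversation_id=conversation.id,
    app_mode='chat',
    message_id=message.id
)

# worker thread: stream model output into the queue
for chunk in llm_stream:                       # each chunk is an LLMResultChunk
    queue_manager.publish_chunk_message(chunk)
queue_manager.publish_message_end(llm_result)  # LLMResult; also stops the listener

# request thread: relay queued events to the client until the None sentinel arrives
for queue_message in queue_manager.listen():
    yield queue_message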
| @@ -2,30 +2,40 @@ import json | |||
| import logging | |||
| import time | |||
| from typing import Any, Dict, List, Union, Optional | |||
| from typing import Any, Dict, List, Union, Optional, cast | |||
| from langchain.agents import openai_functions_agent, openai_functions_multi_agent | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult | |||
| from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage, PromptMessage | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from extensions.ext_database import db | |||
| from models.model import MessageChain, MessageAgentThought, Message | |||
| class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None: | |||
| def __init__(self, model_config: ModelConfigEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| message: Message, | |||
| message_chain: MessageChain) -> None: | |||
| """Initialize callback handler.""" | |||
| self.model_instance = model_instance | |||
| self.conversation_message_task = conversation_message_task | |||
| self.model_config = model_config | |||
| self.queue_manager = queue_manager | |||
| self.message = message | |||
| self.message_chain = message_chain | |||
| model_type_instance = self.model_config.provider_model_bundle.model_type_instance | |||
| self.model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| self._agent_loops = [] | |||
| self._current_loop = None | |||
| self._message_agent_thought = None | |||
| self.current_chain = None | |||
| @property | |||
| def agent_loops(self) -> List[AgentLoop]: | |||
| @@ -46,66 +56,61 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| """Whether to ignore chain callbacks.""" | |||
| return True | |||
| def on_chat_model_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| if not self._current_loop: | |||
| # Agent starts with an LLM query | |||
| self._current_loop = AgentLoop( | |||
| position=len(self._agent_loops) + 1, | |||
| prompt="\n".join([message.content for message in messages[0]]), | |||
| status='llm_started', | |||
| started_at=time.perf_counter() | |||
| ) | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| """Print out the prompts.""" | |||
| # serialized={'name': 'OpenAI'} | |||
| # prompts=['Answer the following questions...\nThought:'] | |||
| # kwargs={} | |||
| def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None: | |||
| if not self._current_loop: | |||
| # Agent starts with an LLM query | |||
| self._current_loop = AgentLoop( | |||
| position=len(self._agent_loops) + 1, | |||
| prompt=prompts[0], | |||
| prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]), | |||
| status='llm_started', | |||
| started_at=time.perf_counter() | |||
| ) | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| """Do nothing.""" | |||
| # kwargs={} | |||
| def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None: | |||
| if self._current_loop and self._current_loop.status == 'llm_started': | |||
| self._current_loop.status = 'llm_end' | |||
| if response.llm_output: | |||
| self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | |||
| if result.usage: | |||
| self._current_loop.prompt_tokens = result.usage.prompt_tokens | |||
| else: | |||
| self._current_loop.prompt_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self._current_loop.prompt)] | |||
| self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens( | |||
| model=self.model_config.model, | |||
| credentials=self.model_config.credentials, | |||
| prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)] | |||
| ) | |||
| completion_generation = response.generations[0][0] | |||
| if isinstance(completion_generation, ChatGeneration): | |||
| completion_message = completion_generation.message | |||
| if 'function_call' in completion_message.additional_kwargs: | |||
| self._current_loop.completion \ | |||
| = json.dumps({'function_call': completion_message.additional_kwargs['function_call']}) | |||
| else: | |||
| self._current_loop.completion = response.generations[0][0].text | |||
| completion_message = result.message | |||
| if completion_message.tool_calls: | |||
| self._current_loop.completion \ | |||
| = json.dumps({'function_call': completion_message.tool_calls}) | |||
| else: | |||
| self._current_loop.completion = completion_generation.text | |||
| self._current_loop.completion = completion_message.content | |||
| if response.llm_output: | |||
| self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | |||
| if result.usage: | |||
| self._current_loop.completion_tokens = result.usage.completion_tokens | |||
| else: | |||
| self._current_loop.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self._current_loop.completion)] | |||
| self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens( | |||
| model=self.model_config.model, | |||
| credentials=self.model_config.credentials, | |||
| prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)] | |||
| ) | |||
| def on_chat_model_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| pass | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| pass | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| """Do nothing.""" | |||
| pass | |||
| def on_llm_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| @@ -150,10 +155,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| if completion is not None: | |||
| self._current_loop.completion = completion | |||
| self._message_agent_thought = self.conversation_message_task.on_agent_start( | |||
| self.current_chain, | |||
| self._current_loop | |||
| ) | |||
| self._message_agent_thought = self._init_agent_thought() | |||
| def on_tool_end( | |||
| self, | |||
| @@ -176,9 +178,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._current_loop.completed_at = time.perf_counter() | |||
| self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at | |||
| self.conversation_message_task.on_agent_end( | |||
| self._message_agent_thought, self.model_instance, self._current_loop | |||
| ) | |||
| self._complete_agent_thought(self._message_agent_thought) | |||
| self._agent_loops.append(self._current_loop) | |||
| self._current_loop = None | |||
| @@ -202,17 +202,62 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._current_loop.completed_at = time.perf_counter() | |||
| self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at | |||
| self._current_loop.thought = '[DONE]' | |||
| self._message_agent_thought = self.conversation_message_task.on_agent_start( | |||
| self.current_chain, | |||
| self._current_loop | |||
| ) | |||
| self._message_agent_thought = self._init_agent_thought() | |||
| self.conversation_message_task.on_agent_end( | |||
| self._message_agent_thought, self.model_instance, self._current_loop | |||
| ) | |||
| self._complete_agent_thought(self._message_agent_thought) | |||
| self._agent_loops.append(self._current_loop) | |||
| self._current_loop = None | |||
| self._message_agent_thought = None | |||
| elif not self._current_loop and self._agent_loops: | |||
| self._agent_loops[-1].status = 'agent_finish' | |||
| def _init_agent_thought(self) -> MessageAgentThought: | |||
| message_agent_thought = MessageAgentThought( | |||
| message_id=self.message.id, | |||
| message_chain_id=self.message_chain.id, | |||
| position=self._current_loop.position, | |||
| thought=self._current_loop.thought, | |||
| tool=self._current_loop.tool_name, | |||
| tool_input=self._current_loop.tool_input, | |||
| message=self._current_loop.prompt, | |||
| message_price_unit=0, | |||
| answer=self._current_loop.completion, | |||
| answer_price_unit=0, | |||
| created_by_role=('account' if self.message.from_source == 'console' else 'end_user'), | |||
| created_by=(self.message.from_account_id | |||
| if self.message.from_source == 'console' else self.message.from_end_user_id) | |||
| ) | |||
| db.session.add(message_agent_thought) | |||
| db.session.commit() | |||
| self.queue_manager.publish_agent_thought(message_agent_thought) | |||
| return message_agent_thought | |||
| def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None: | |||
| loop_message_tokens = self._current_loop.prompt_tokens | |||
| loop_answer_tokens = self._current_loop.completion_tokens | |||
| # transform usage | |||
| llm_usage = self.model_type_instance._calc_response_usage( | |||
| self.model_config.model, | |||
| self.model_config.credentials, | |||
| loop_message_tokens, | |||
| loop_answer_tokens | |||
| ) | |||
| message_agent_thought.observation = self._current_loop.tool_output | |||
| message_agent_thought.tool_process_data = '' # currently not supported | |||
| message_agent_thought.message_token = loop_message_tokens | |||
| message_agent_thought.message_unit_price = llm_usage.prompt_unit_price | |||
| message_agent_thought.message_price_unit = llm_usage.prompt_price_unit | |||
| message_agent_thought.answer_token = loop_answer_tokens | |||
| message_agent_thought.answer_unit_price = llm_usage.completion_unit_price | |||
| message_agent_thought.answer_price_unit = llm_usage.completion_price_unit | |||
| message_agent_thought.latency = self._current_loop.latency | |||
| message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens | |||
| message_agent_thought.total_price = llm_usage.total_price | |||
| message_agent_thought.currency = llm_usage.currency | |||
| db.session.commit() | |||
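A hedged sketch of how this handler might be wired into an agent run; `agent_executor`, `message_chain` and `app_orchestration_config` are placeholders for objects built elsewhere in the pull request:

agent_callback = AgentLoopGatherCallbackHandler(
    model_config=app_orchestration_config.model_config,
    queue_manager=queue_manager,
    message=message,
    message_chain=message_chain
)

# each LLM/tool round trip is persisted as a MessageAgentThought
# and pushed to the client via the queue manager
agent_executor.run(query, callbacks=[agent_callback])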
| @@ -1,74 +0,0 @@ | |||
| import json | |||
| import logging | |||
| from json import JSONDecodeError | |||
| from typing import Any, Dict, List, Union, Optional | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| class DatasetToolCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| self.queries = [] | |||
| self.conversation_message_task = conversation_message_task | |||
| @property | |||
| def always_verbose(self) -> bool: | |||
| """Whether to call verbose callbacks even if verbose is False.""" | |||
| return True | |||
| @property | |||
| def ignore_llm(self) -> bool: | |||
| """Whether to ignore LLM callbacks.""" | |||
| return True | |||
| @property | |||
| def ignore_chain(self) -> bool: | |||
| """Whether to ignore chain callbacks.""" | |||
| return True | |||
| @property | |||
| def ignore_agent(self) -> bool: | |||
| """Whether to ignore agent callbacks.""" | |||
| return False | |||
| def on_tool_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| input_str: str, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| tool_name: str = serialized.get('name') | |||
| dataset_id = tool_name.removeprefix('dataset-') | |||
| try: | |||
| input_dict = json.loads(input_str.replace("'", "\"")) | |||
| query = input_dict.get('query') | |||
| except JSONDecodeError: | |||
| query = input_str | |||
| self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=query)) | |||
| def on_tool_end( | |||
| self, | |||
| output: str, | |||
| color: Optional[str] = None, | |||
| observation_prefix: Optional[str] = None, | |||
| llm_prefix: Optional[str] = None, | |||
| **kwargs: Any, | |||
| ) -> None: | |||
| pass | |||
| def on_tool_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| """Do nothing.""" | |||
| logging.debug("Dataset tool on_llm_error: %s", error) | |||
| @@ -1,16 +0,0 @@ | |||
| from pydantic import BaseModel | |||
| class ChainResult(BaseModel): | |||
| type: str = None | |||
| prompt: dict = None | |||
| completion: dict = None | |||
| status: str = 'chain_started' | |||
| completed: bool = False | |||
| started_at: float = None | |||
| completed_at: float = None | |||
| agent_result: dict = None | |||
| """only when type is 'AgentExecutor'""" | |||
| @@ -1,6 +0,0 @@ | |||
| from pydantic import BaseModel | |||
| class DatasetQueryObj(BaseModel): | |||
| dataset_id: str = None | |||
| query: str = None | |||
| @@ -1,8 +0,0 @@ | |||
| from pydantic import BaseModel | |||
| class LLMMessage(BaseModel): | |||
| prompt: str = '' | |||
| prompt_tokens: int = 0 | |||
| completion: str = '' | |||
| completion_tokens: int = 0 | |||
| @@ -1,17 +1,44 @@ | |||
| from typing import List | |||
| from typing import List, Union | |||
| from langchain.schema import Document | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.entities.application_entities import InvokeFrom | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment | |||
| from models.dataset import DocumentSegment, DatasetQuery | |||
| from models.model import DatasetRetrieverResource | |||
| class DatasetIndexToolCallbackHandler: | |||
| """Callback handler for dataset tool.""" | |||
| def __init__(self, conversation_message_task: ConversationMessageTask) -> None: | |||
| self.conversation_message_task = conversation_message_task | |||
| def __init__(self, queue_manager: ApplicationQueueManager, | |||
| app_id: str, | |||
| message_id: str, | |||
| user_id: str, | |||
| invoke_from: InvokeFrom) -> None: | |||
| self._queue_manager = queue_manager | |||
| self._app_id = app_id | |||
| self._message_id = message_id | |||
| self._user_id = user_id | |||
| self._invoke_from = invoke_from | |||
| def on_query(self, query: str, dataset_id: str) -> None: | |||
| """ | |||
| Handle query. | |||
| """ | |||
| dataset_query = DatasetQuery( | |||
| dataset_id=dataset_id, | |||
| content=query, | |||
| source='app', | |||
| source_app_id=self._app_id, | |||
| created_by_role=('account' | |||
| if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'), | |||
| created_by=self._user_id | |||
| ) | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| def on_tool_end(self, documents: List[Document]) -> None: | |||
| """Handle tool end.""" | |||
| @@ -30,4 +57,27 @@ class DatasetIndexToolCallbackHandler: | |||
| def return_retriever_resource_info(self, resource: List): | |||
| """Handle return_retriever_resource_info.""" | |||
| self.conversation_message_task.on_dataset_query_finish(resource) | |||
| if resource and len(resource) > 0: | |||
| for item in resource: | |||
| dataset_retriever_resource = DatasetRetrieverResource( | |||
| message_id=self._message_id, | |||
| position=item.get('position'), | |||
| dataset_id=item.get('dataset_id'), | |||
| dataset_name=item.get('dataset_name'), | |||
| document_id=item.get('document_id'), | |||
| document_name=item.get('document_name'), | |||
| data_source_type=item.get('data_source_type'), | |||
| segment_id=item.get('segment_id'), | |||
| score=item.get('score') if 'score' in item else None, | |||
| hit_count=item.get('hit_count') if 'hit_count' in item else None, | |||
| word_count=item.get('word_count') if 'word_count' in item else None, | |||
| segment_position=item.get('segment_position') if 'segment_position' in item else None, | |||
| index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, | |||
| content=item.get('content'), | |||
| retriever_from=item.get('retriever_from'), | |||
| created_by=self._user_id | |||
| ) | |||
| db.session.add(dataset_retriever_resource) | |||
| db.session.commit() | |||
| self._queue_manager.publish_retriever_resources(resource) | |||
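A minimal usage sketch for this callback handler; the resource dict mirrors the keys read above, and all identifiers are placeholders:

handler = DatasetIndexToolCallbackHandler(
    queue_manager=queue_manager,
    app_id=app_record.id,
    message_id=message.id,
    user_id=user_id,
    invoke_from=InvokeFrom.DEBUGGER
)

# record the dataset query issued by the app
handler.on_query(query='How do I reset my password?', dataset_id=dataset_id)

# persist citation records for the retrieved segments and publish them to the queue
handler.return_retriever_resource_info([{
    'position': 1,
    'dataset_id': dataset_id,
    'dataset_name': 'FAQ',
    'document_id': document_id,
    'document_name': 'faq.md',
    'data_source_type': 'upload_file',
    'segment_id': segment_id,
    'score': 0.82,
    'content': '...',
    'retriever_from': 'dev'
}])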
| @@ -1,284 +0,0 @@ | |||
| import logging | |||
| import threading | |||
| import time | |||
| from typing import Any, Dict, List, Union, Optional | |||
| from flask import Flask, current_app | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import LLMResult, BaseMessage | |||
| from pydantic import BaseModel | |||
| from core.callback_handler.entity.llm_message import LLMMessage | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ | |||
| ConversationTaskInterruptException | |||
| from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \ | |||
| ImagePromptMessageFile | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.moderation.base import ModerationOutputsResult, ModerationAction | |||
| from core.moderation.factory import ModerationFactory | |||
| class ModerationRule(BaseModel): | |||
| type: str | |||
| config: Dict[str, Any] | |||
| class LLMCallbackHandler(BaseCallbackHandler): | |||
| raise_error: bool = True | |||
| def __init__(self, model_instance: BaseLLM, | |||
| conversation_message_task: ConversationMessageTask): | |||
| self.model_instance = model_instance | |||
| self.llm_message = LLMMessage() | |||
| self.start_at = None | |||
| self.conversation_message_task = conversation_message_task | |||
| self.output_moderation_handler = None | |||
| self.init_output_moderation() | |||
| def init_output_moderation(self): | |||
| app_model_config = self.conversation_message_task.app_model_config | |||
| sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict | |||
| if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"): | |||
| self.output_moderation_handler = OutputModerationHandler( | |||
| tenant_id=self.conversation_message_task.tenant_id, | |||
| app_id=self.conversation_message_task.app.id, | |||
| rule=ModerationRule( | |||
| type=sensitive_word_avoidance_dict.get("type"), | |||
| config=sensitive_word_avoidance_dict.get("config") | |||
| ), | |||
| on_message_replace_func=self.conversation_message_task.on_message_replace | |||
| ) | |||
| @property | |||
| def always_verbose(self) -> bool: | |||
| """Whether to call verbose callbacks even if verbose is False.""" | |||
| return True | |||
| def on_chat_model_start( | |||
| self, | |||
| serialized: Dict[str, Any], | |||
| messages: List[List[BaseMessage]], | |||
| **kwargs: Any | |||
| ) -> Any: | |||
| real_prompts = [] | |||
| for message in messages[0]: | |||
| if message.type == 'human': | |||
| role = 'user' | |||
| elif message.type == 'ai': | |||
| role = 'assistant' | |||
| else: | |||
| role = 'system' | |||
| real_prompts.append({ | |||
| "role": role, | |||
| "text": message.content, | |||
| "files": [{ | |||
| "type": file.type.value, | |||
| "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:], | |||
| "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None, | |||
| } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])] | |||
| }) | |||
| self.llm_message.prompt = real_prompts | |||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0])) | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| ) -> None: | |||
| self.llm_message.prompt = [{ | |||
| "role": 'user', | |||
| "text": prompts[0] | |||
| }] | |||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])]) | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| if self.output_moderation_handler: | |||
| self.output_moderation_handler.stop_thread() | |||
| self.llm_message.completion = self.output_moderation_handler.moderation_completion( | |||
| completion=response.generations[0][0].text, | |||
| public_event=True if self.conversation_message_task.streaming else False | |||
| ) | |||
| else: | |||
| self.llm_message.completion = response.generations[0][0].text | |||
| if not self.conversation_message_task.streaming: | |||
| self.conversation_message_task.append_message_text(self.llm_message.completion) | |||
| if response.llm_output and 'token_usage' in response.llm_output: | |||
| if 'prompt_tokens' in response.llm_output['token_usage']: | |||
| self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] | |||
| if 'completion_tokens' in response.llm_output['token_usage']: | |||
| self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens'] | |||
| else: | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)]) | |||
| else: | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)]) | |||
| self.conversation_message_task.save_message(self.llm_message) | |||
| def on_llm_new_token(self, token: str, **kwargs: Any) -> None: | |||
| if self.output_moderation_handler and self.output_moderation_handler.should_direct_output(): | |||
| # stop subscribe new token when output moderation should direct output | |||
| ex = ConversationTaskInterruptException() | |||
| self.on_llm_error(error=ex) | |||
| raise ex | |||
| try: | |||
| self.conversation_message_task.append_message_text(token) | |||
| self.llm_message.completion += token | |||
| if self.output_moderation_handler: | |||
| self.output_moderation_handler.append_new_token(token) | |||
| except ConversationTaskStoppedException as ex: | |||
| self.on_llm_error(error=ex) | |||
| raise ex | |||
| def on_llm_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| """Do nothing.""" | |||
| if self.output_moderation_handler: | |||
| self.output_moderation_handler.stop_thread() | |||
| if isinstance(error, ConversationTaskStoppedException): | |||
| if self.conversation_message_task.streaming: | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)] | |||
| ) | |||
| self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) | |||
| if isinstance(error, ConversationTaskInterruptException): | |||
| self.llm_message.completion = self.output_moderation_handler.get_final_output() | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)] | |||
| ) | |||
| self.conversation_message_task.save_message(llm_message=self.llm_message) | |||
| else: | |||
| logging.debug("on_llm_error: %s", error) | |||
| class OutputModerationHandler(BaseModel): | |||
| DEFAULT_BUFFER_SIZE: int = 300 | |||
| tenant_id: str | |||
| app_id: str | |||
| rule: ModerationRule | |||
| on_message_replace_func: Any | |||
| thread: Optional[threading.Thread] = None | |||
| thread_running: bool = True | |||
| buffer: str = '' | |||
| is_final_chunk: bool = False | |||
| final_output: Optional[str] = None | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def should_direct_output(self): | |||
| return self.final_output is not None | |||
| def get_final_output(self): | |||
| return self.final_output | |||
| def append_new_token(self, token: str): | |||
| self.buffer += token | |||
| if not self.thread: | |||
| self.thread = self.start_thread() | |||
| def moderation_completion(self, completion: str, public_event: bool = False) -> str: | |||
| self.buffer = completion | |||
| self.is_final_chunk = True | |||
| result = self.moderation( | |||
| tenant_id=self.tenant_id, | |||
| app_id=self.app_id, | |||
| moderation_buffer=completion | |||
| ) | |||
| if not result or not result.flagged: | |||
| return completion | |||
| if result.action == ModerationAction.DIRECT_OUTPUT: | |||
| final_output = result.preset_response | |||
| else: | |||
| final_output = result.text | |||
| if public_event: | |||
| self.on_message_replace_func(final_output) | |||
| return final_output | |||
| def start_thread(self) -> threading.Thread: | |||
| buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE)) | |||
| thread = threading.Thread(target=self.worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE | |||
| }) | |||
| thread.start() | |||
| return thread | |||
| def stop_thread(self): | |||
| if self.thread and self.thread.is_alive(): | |||
| self.thread_running = False | |||
| def worker(self, flask_app: Flask, buffer_size: int): | |||
| with flask_app.app_context(): | |||
| current_length = 0 | |||
| while self.thread_running: | |||
| moderation_buffer = self.buffer | |||
| buffer_length = len(moderation_buffer) | |||
| if not self.is_final_chunk: | |||
| chunk_length = buffer_length - current_length | |||
| if 0 <= chunk_length < buffer_size: | |||
| time.sleep(1) | |||
| continue | |||
| current_length = buffer_length | |||
| result = self.moderation( | |||
| tenant_id=self.tenant_id, | |||
| app_id=self.app_id, | |||
| moderation_buffer=moderation_buffer | |||
| ) | |||
| if not result or not result.flagged: | |||
| continue | |||
| if result.action == ModerationAction.DIRECT_OUTPUT: | |||
| final_output = result.preset_response | |||
| self.final_output = final_output | |||
| else: | |||
| final_output = result.text + self.buffer[len(moderation_buffer):] | |||
| # trigger replace event | |||
| if self.thread_running: | |||
| self.on_message_replace_func(final_output) | |||
| if result.action == ModerationAction.DIRECT_OUTPUT: | |||
| break | |||
| def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]: | |||
| try: | |||
| moderation_factory = ModerationFactory( | |||
| name=self.rule.type, | |||
| app_id=app_id, | |||
| tenant_id=tenant_id, | |||
| config=self.rule.config | |||
| ) | |||
| result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer) | |||
| return result | |||
| except Exception as e: | |||
| logging.error("Moderation Output error: %s", e) | |||
| return None | |||
| @@ -1,76 +0,0 @@ | |||
| import logging | |||
| import time | |||
| from typing import Any, Dict, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from core.callback_handler.entity.chain_result import ChainResult | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| class MainChainGatherCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| self._current_chain_result = None | |||
| self._current_chain_message = None | |||
| self.conversation_message_task = conversation_message_task | |||
| self.agent_callback = None | |||
| def clear_chain_results(self) -> None: | |||
| self._current_chain_result = None | |||
| self._current_chain_message = None | |||
| if self.agent_callback: | |||
| self.agent_callback.current_chain = None | |||
| @property | |||
| def always_verbose(self) -> bool: | |||
| """Whether to call verbose callbacks even if verbose is False.""" | |||
| return True | |||
| @property | |||
| def ignore_llm(self) -> bool: | |||
| """Whether to ignore LLM callbacks.""" | |||
| return True | |||
| @property | |||
| def ignore_agent(self) -> bool: | |||
| """Whether to ignore agent callbacks.""" | |||
| return True | |||
| def on_chain_start( | |||
| self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any | |||
| ) -> None: | |||
| """Print out that we are entering a chain.""" | |||
| if not self._current_chain_result: | |||
| chain_type = serialized['id'][-1] | |||
| if chain_type: | |||
| self._current_chain_result = ChainResult( | |||
| type=chain_type, | |||
| prompt=inputs, | |||
| started_at=time.perf_counter() | |||
| ) | |||
| self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) | |||
| if self.agent_callback: | |||
| self.agent_callback.current_chain = self._current_chain_message | |||
| def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: | |||
| """Print out that we finished a chain.""" | |||
| if self._current_chain_result and self._current_chain_result.status == 'chain_started': | |||
| self._current_chain_result.status = 'chain_ended' | |||
| self._current_chain_result.completion = outputs | |||
| self._current_chain_result.completed = True | |||
| self._current_chain_result.completed_at = time.perf_counter() | |||
| self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result) | |||
| self.clear_chain_results() | |||
| def on_chain_error( | |||
| self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any | |||
| ) -> None: | |||
| logging.debug("Dataset tool on_chain_error: %s", error) | |||
| self.clear_chain_results() | |||
| @@ -79,8 +79,11 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): | |||
| """Run on agent action.""" | |||
| tool = action.tool | |||
| tool_input = action.tool_input | |||
| action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 | |||
| thought = action.log[:action_name_position].strip() if action.log else '' | |||
| try: | |||
| action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 | |||
| thought = action.log[:action_name_position].strip() if action.log else '' | |||
| except ValueError: | |||
| thought = '' | |||
| log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" | |||
| print_text("\n[on_agent_action]\n" + log + "\n", color='green') | |||
| @@ -5,15 +5,19 @@ from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| from langchain.schema import LLMResult, Generation | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_manager import ModelInstance | |||
| from core.entities.message_entities import lc_messages_to_prompt_messages | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| class LLMChain(LCLLMChain): | |||
| model_instance: BaseLLM | |||
| model_config: ModelConfigEntity | |||
| """The language model instance to use.""" | |||
| llm: BaseLanguageModel = FakeLLM(response="") | |||
| parameters: Dict[str, Any] = {} | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None | |||
| def generate( | |||
| self, | |||
| @@ -23,14 +27,23 @@ class LLMChain(LCLLMChain): | |||
| """Generate LLM result from inputs.""" | |||
| prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) | |||
| messages = prompts[0].to_messages() | |||
| prompt_messages = to_prompt_messages(messages) | |||
| result = self.model_instance.run( | |||
| messages=prompt_messages, | |||
| stop=stop | |||
| prompt_messages = lc_messages_to_prompt_messages(messages) | |||
| model_instance = ModelInstance( | |||
| provider_model_bundle=self.model_config.provider_model_bundle, | |||
| model=self.model_config.model, | |||
| ) | |||
| result = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| stream=False, | |||
| stop=stop, | |||
| callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None, | |||
| model_parameters=self.parameters | |||
| ) | |||
| generations = [ | |||
| [Generation(text=result.content)] | |||
| [Generation(text=result.message.content)] | |||
| ] | |||
| return LLMResult(generations=generations) | |||
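A hedged construction sketch for this wrapper, assuming a populated ModelConfigEntity; the prompt template, parameters and `long_text` below are illustrative placeholders:

from langchain.prompts import PromptTemplate

chain = LLMChain(
    model_config=model_config,   # ModelConfigEntity with provider_model_bundle set
    prompt=PromptTemplate.from_template('Summarize: {text}'),
    parameters={'temperature': 0.1, 'max_tokens': 256}
)

# generate() above routes the call through ModelInstance.invoke_llm rather than the fake llm field
summary = chain.run(text=long_text)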
| @@ -1,501 +0,0 @@ | |||
| import concurrent | |||
| import json | |||
| import logging | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Optional, List, Union, Tuple | |||
| from flask import current_app, Flask | |||
| from requests.exceptions import ChunkedEncodingError | |||
| from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy | |||
| from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | |||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \ | |||
| ConversationTaskInterruptException | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||
| from core.file.file_obj import FileObj | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| ReadOnlyConversationTokenDBBufferSharedMemory | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.orchestrator_rule_parser import OrchestratorRuleParser | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from core.prompt.prompt_transform import PromptTransform | |||
| from models.dataset import Dataset | |||
| from models.model import App, AppModelConfig, Account, Conversation, EndUser | |||
| from core.moderation.base import ModerationException, ModerationAction | |||
| from core.moderation.factory import ModerationFactory | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| class Completion: | |||
| @classmethod | |||
| def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, | |||
| files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation], | |||
| streaming: bool, is_override: bool = False, retriever_from: str = 'dev', | |||
| auto_generate_name: bool = True, from_source: str = 'console'): | |||
| """ | |||
| errors: ProviderTokenNotInitError | |||
| """ | |||
| query = PromptTemplateParser.remove_template_variables(query) | |||
| memory = None | |||
| if conversation: | |||
| # get memory of conversation (read-only) | |||
| memory = cls.get_memory_from_conversation( | |||
| tenant_id=app.tenant_id, | |||
| app_model_config=app_model_config, | |||
| conversation=conversation, | |||
| return_messages=False | |||
| ) | |||
| inputs = conversation.inputs | |||
| final_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||
| tenant_id=app.tenant_id, | |||
| model_config=app_model_config.model_dict, | |||
| streaming=streaming | |||
| ) | |||
| conversation_message_task = ConversationMessageTask( | |||
| task_id=task_id, | |||
| app=app, | |||
| app_model_config=app_model_config, | |||
| user=user, | |||
| conversation=conversation, | |||
| is_override=is_override, | |||
| inputs=inputs, | |||
| query=query, | |||
| files=files, | |||
| streaming=streaming, | |||
| model_instance=final_model_instance, | |||
| auto_generate_name=auto_generate_name | |||
| ) | |||
| prompt_message_files = [file.prompt_message_file for file in files] | |||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||
| mode=app.mode, | |||
| model_instance=final_model_instance, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs, | |||
| files=prompt_message_files | |||
| ) | |||
| # init orchestrator rule parser | |||
| orchestrator_rule_parser = OrchestratorRuleParser( | |||
| tenant_id=app.tenant_id, | |||
| app_model_config=app_model_config | |||
| ) | |||
| try: | |||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||
| try: | |||
| # process sensitive_word_avoidance | |||
| inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query) | |||
| except ModerationException as e: | |||
| cls.run_final_llm( | |||
| model_instance=final_model_instance, | |||
| mode=app.mode, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs, | |||
| files=prompt_message_files, | |||
| agent_execute_result=None, | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| fake_response=str(e) | |||
| ) | |||
| return | |||
| # check annotation reply | |||
| annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source) | |||
| if annotation_reply: | |||
| return | |||
| # fill in variable inputs from external data tools if exists | |||
| external_data_tools = app_model_config.external_data_tools_list | |||
| if external_data_tools: | |||
| inputs = cls.fill_in_inputs_from_external_data_tools( | |||
| tenant_id=app.tenant_id, | |||
| app_id=app.id, | |||
| external_data_tools=external_data_tools, | |||
| inputs=inputs, | |||
| query=query | |||
| ) | |||
| # get agent executor | |||
| agent_executor = orchestrator_rule_parser.to_agent_executor( | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| rest_tokens=rest_tokens_for_context_and_memory, | |||
| chain_callback=chain_callback, | |||
| tenant_id=app.tenant_id, | |||
| retriever_from=retriever_from | |||
| ) | |||
| query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs) | |||
| # run agent executor | |||
| agent_execute_result = None | |||
| if query_for_agent and agent_executor: | |||
| should_use_agent = agent_executor.should_use_agent(query_for_agent) | |||
| if should_use_agent: | |||
| agent_execute_result = agent_executor.run(query_for_agent) | |||
| # When no extra pre prompt is specified, | |||
| # the output of the agent can be used directly as the main output content without calling LLM again | |||
| fake_response = None | |||
| if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ | |||
| and agent_execute_result.strategy not in [PlanningStrategy.ROUTER, | |||
| PlanningStrategy.REACT_ROUTER]: | |||
| fake_response = agent_execute_result.output | |||
| # run the final llm | |||
| cls.run_final_llm( | |||
| model_instance=final_model_instance, | |||
| mode=app.mode, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs, | |||
| files=prompt_message_files, | |||
| agent_execute_result=agent_execute_result, | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| fake_response=fake_response | |||
| ) | |||
| except (ConversationTaskInterruptException, ConversationTaskStoppedException): | |||
| return | |||
| except ChunkedEncodingError as e: | |||
| # Interrupt by LLM (like OpenAI), handle it. | |||
| logging.warning(f'ChunkedEncodingError: {e}') | |||
| return | |||
| @classmethod | |||
| def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, | |||
| query: str): | |||
| if not app_model_config.sensitive_word_avoidance_dict['enabled']: | |||
| return inputs, query | |||
| type = app_model_config.sensitive_word_avoidance_dict['type'] | |||
| moderation = ModerationFactory(type, app_id, tenant_id, | |||
| app_model_config.sensitive_word_avoidance_dict['config']) | |||
| moderation_result = moderation.moderation_for_inputs(inputs, query) | |||
| if not moderation_result.flagged: | |||
| return inputs, query | |||
| if moderation_result.action == ModerationAction.DIRECT_OUTPUT: | |||
| raise ModerationException(moderation_result.preset_response) | |||
| elif moderation_result.action == ModerationAction.OVERRIDED: | |||
| inputs = moderation_result.inputs | |||
| query = moderation_result.query | |||
| return inputs, query | |||
| @classmethod | |||
| def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict], | |||
| inputs: dict, query: str) -> dict: | |||
| """ | |||
| Fill in variable inputs from external data tools if exists. | |||
| :param tenant_id: workspace id | |||
| :param app_id: app id | |||
| :param external_data_tools: external data tools configs | |||
| :param inputs: the inputs | |||
| :param query: the query | |||
| :return: the filled inputs | |||
| """ | |||
| # Group tools by type and config | |||
| grouped_tools = {} | |||
| for tool in external_data_tools: | |||
| if not tool.get("enabled"): | |||
| continue | |||
| tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True)) | |||
| grouped_tools.setdefault(tool_key, []).append(tool) | |||
| results = {} | |||
| with ThreadPoolExecutor() as executor: | |||
| futures = {} | |||
| for tool in external_data_tools: | |||
| if not tool.get("enabled"): | |||
| continue | |||
| future = executor.submit( | |||
| cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool, | |||
| inputs, query | |||
| ) | |||
| futures[future] = tool | |||
| for future in concurrent.futures.as_completed(futures): | |||
| tool_variable, result = future.result() | |||
| results[tool_variable] = result | |||
| inputs.update(results) | |||
| return inputs | |||
| @classmethod | |||
| def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict, | |||
| inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]: | |||
| with flask_app.app_context(): | |||
| tool_variable = external_data_tool.get("variable") | |||
| tool_type = external_data_tool.get("type") | |||
| tool_config = external_data_tool.get("config") | |||
| external_data_tool_factory = ExternalDataToolFactory( | |||
| name=tool_type, | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| variable=tool_variable, | |||
| config=tool_config | |||
| ) | |||
| # query external data tool | |||
| result = external_data_tool_factory.query( | |||
| inputs=inputs, | |||
| query=query | |||
| ) | |||
| return tool_variable, result | |||
| @classmethod | |||
| def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str: | |||
| if app.mode != 'completion': | |||
| return query | |||
| return inputs.get(app_model_config.dataset_query_variable, "") | |||
| @classmethod | |||
| def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, | |||
| inputs: dict, | |||
| files: List[PromptMessageFile], | |||
| agent_execute_result: Optional[AgentExecuteResult], | |||
| conversation_message_task: ConversationMessageTask, | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], | |||
| fake_response: Optional[str]): | |||
| prompt_transform = PromptTransform() | |||
| # get llm prompt | |||
| if app_model_config.prompt_type == 'simple': | |||
| prompt_messages, stop_words = prompt_transform.get_prompt( | |||
| app_mode=mode, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| inputs=inputs, | |||
| query=query, | |||
| files=files, | |||
| context=agent_execute_result.output if agent_execute_result else None, | |||
| memory=memory, | |||
| model_instance=model_instance | |||
| ) | |||
| else: | |||
| prompt_messages = prompt_transform.get_advanced_prompt( | |||
| app_mode=mode, | |||
| app_model_config=app_model_config, | |||
| inputs=inputs, | |||
| query=query, | |||
| files=files, | |||
| context=agent_execute_result.output if agent_execute_result else None, | |||
| memory=memory, | |||
| model_instance=model_instance | |||
| ) | |||
| model_config = app_model_config.model_dict | |||
| completion_params = model_config.get("completion_params", {}) | |||
| stop_words = completion_params.get("stop", []) | |||
| cls.recale_llm_max_tokens( | |||
| model_instance=model_instance, | |||
| prompt_messages=prompt_messages, | |||
| ) | |||
| response = model_instance.run( | |||
| messages=prompt_messages, | |||
| stop=stop_words if stop_words else None, | |||
| callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], | |||
| fake_response=fake_response | |||
| ) | |||
| return response | |||
| @classmethod | |||
| def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | |||
| max_token_limit: int) -> str: | |||
| """Get memory messages.""" | |||
| memory.max_token_limit = max_token_limit | |||
| memory_key = memory.memory_variables[0] | |||
| external_context = memory.load_memory_variables({}) | |||
| return external_context[memory_key] | |||
| @classmethod | |||
| def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask, | |||
| from_source: str) -> bool: | |||
| """Get memory messages.""" | |||
| app_model_config = conversation_message_task.app_model_config | |||
| app = conversation_message_task.app | |||
| annotation_reply = app_model_config.annotation_reply_dict | |||
| if annotation_reply['enabled']: | |||
| try: | |||
| score_threshold = annotation_reply.get('score_threshold', 1) | |||
| embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name'] | |||
| embedding_model_name = annotation_reply['embedding_model']['embedding_model_name'] | |||
| # get embedding model | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=app.tenant_id, | |||
| model_provider_name=embedding_provider_name, | |||
| model_name=embedding_model_name | |||
| ) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| embedding_provider_name, | |||
| embedding_model_name, | |||
| 'annotation' | |||
| ) | |||
| dataset = Dataset( | |||
| id=app.id, | |||
| tenant_id=app.tenant_id, | |||
| indexing_technique='high_quality', | |||
| embedding_model_provider=embedding_provider_name, | |||
| embedding_model=embedding_model_name, | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings, | |||
| attributes=['doc_id', 'annotation_id', 'app_id'] | |||
| ) | |||
| documents = vector_index.search( | |||
| conversation_message_task.query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': 1, | |||
| 'score_threshold': score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| } | |||
| ) | |||
| if documents: | |||
| annotation_id = documents[0].metadata['annotation_id'] | |||
| score = documents[0].metadata['score'] | |||
| annotation = AppAnnotationService.get_annotation_by_id(annotation_id) | |||
| if annotation: | |||
| conversation_message_task.annotation_end(annotation.content, annotation.id, annotation.account.name) | |||
| # insert annotation history | |||
| AppAnnotationService.add_annotation_history(annotation.id, | |||
| app.id, | |||
| annotation.question, | |||
| annotation.content, | |||
| conversation_message_task.query, | |||
| conversation_message_task.user.id, | |||
| conversation_message_task.message.id, | |||
| from_source, | |||
| score) | |||
| return True | |||
| except Exception as e: | |||
| logging.warning(f'Query annotation failed, exception: {str(e)}.') | |||
| return False | |||
| return False | |||
| @classmethod | |||
| def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig, | |||
| conversation: Conversation, | |||
| **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory: | |||
| # only used for token counting in memory | |||
| memory_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||
| tenant_id=tenant_id, | |||
| model_config=app_model_config.model_dict | |||
| ) | |||
| # use llm config from conversation | |||
| memory = ReadOnlyConversationTokenDBBufferSharedMemory( | |||
| conversation=conversation, | |||
| model_instance=memory_model_instance, | |||
| max_token_limit=kwargs.get("max_token_limit", 2048), | |||
| memory_key=kwargs.get("memory_key", "chat_history"), | |||
| return_messages=kwargs.get("return_messages", True), | |||
| input_key=kwargs.get("input_key", "input"), | |||
| output_key=kwargs.get("output_key", "output"), | |||
| message_limit=kwargs.get("message_limit", 10), | |||
| ) | |||
| return memory | |||
| @classmethod | |||
| def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig, | |||
| query: str, inputs: dict, files: List[PromptMessageFile]) -> int: | |||
| model_limited_tokens = model_instance.model_rules.max_tokens.max | |||
| max_tokens = model_instance.get_model_kwargs().max_tokens | |||
| if model_limited_tokens is None: | |||
| return -1 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| prompt_transform = PromptTransform() | |||
| # get prompt without memory and context | |||
| if app_model_config.prompt_type == 'simple': | |||
| prompt_messages, _ = prompt_transform.get_prompt( | |||
| app_mode=mode, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| inputs=inputs, | |||
| query=query, | |||
| files=files, | |||
| context=None, | |||
| memory=None, | |||
| model_instance=model_instance | |||
| ) | |||
| else: | |||
| prompt_messages = prompt_transform.get_advanced_prompt( | |||
| app_mode=mode, | |||
| app_model_config=app_model_config, | |||
| inputs=inputs, | |||
| query=query, | |||
| files=files, | |||
| context=None, | |||
| memory=None, | |||
| model_instance=model_instance | |||
| ) | |||
| prompt_tokens = model_instance.get_num_tokens(prompt_messages) | |||
| rest_tokens = model_limited_tokens - max_tokens - prompt_tokens | |||
| if rest_tokens < 0: | |||
| raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | |||
| "or shrink the max token, or switch to a llm with a larger token limit size.") | |||
| return rest_tokens | |||
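As a quick check of the arithmetic above: with a model token limit of 4096, max_tokens of 512 and a 3,000-token prompt, rest_tokens is 4096 - 512 - 3000 = 584; once the prompt exceeds 3,584 tokens the result turns negative and get_validate_rest_tokens raises LLMBadRequestError.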
| @classmethod | |||
| def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]): | |||
| # recalc max_tokens if prompt_tokens + max_tokens exceeds the model token limit | |||
| model_limited_tokens = model_instance.model_rules.max_tokens.max | |||
| max_tokens = model_instance.get_model_kwargs().max_tokens | |||
| if model_limited_tokens is None: | |||
| return | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| prompt_tokens = model_instance.get_num_tokens(prompt_messages) | |||
| if prompt_tokens + max_tokens > model_limited_tokens: | |||
| max_tokens = max(model_limited_tokens - prompt_tokens, 16) | |||
| # update model instance max tokens | |||
| model_kwargs = model_instance.get_model_kwargs() | |||
| model_kwargs.max_tokens = max_tokens | |||
| model_instance.set_model_kwargs(model_kwargs) | |||
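A standalone sketch of the same clamping rule that recale_llm_max_tokens applies, with illustrative numbers; the function name here is hypothetical:

def clamp_max_tokens(model_limit: int, prompt_tokens: int, max_tokens: int) -> int:
    """Shrink max_tokens so prompt + completion fits inside the model's token limit."""
    if prompt_tokens + max_tokens > model_limit:
        # keep at least a 16-token completion budget, mirroring the handler above
        return max(model_limit - prompt_tokens, 16)
    return max_tokens

# e.g. clamp_max_tokens(4096, prompt_tokens=4000, max_tokens=512) == 96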
| @@ -1,517 +0,0 @@ | |||
| import json | |||
| import time | |||
| from typing import Optional, Union, List | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | |||
| from core.callback_handler.entity.llm_message import LLMMessage | |||
| from core.callback_handler.entity.chain_result import ChainResult | |||
| from core.file.file_obj import FileObj | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DatasetQuery | |||
| from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \ | |||
| MessageChain, DatasetRetrieverResource, MessageFile | |||
| class ConversationMessageTask: | |||
| def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, | |||
| inputs: dict, query: str, files: List[FileObj], streaming: bool, | |||
| model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False, | |||
| auto_generate_name: bool = True): | |||
| self.start_at = time.perf_counter() | |||
| self.task_id = task_id | |||
| self.app = app | |||
| self.tenant_id = app.tenant_id | |||
| self.app_model_config = app_model_config | |||
| self.is_override = is_override | |||
| self.user = user | |||
| self.inputs = inputs | |||
| self.query = query | |||
| self.files = files | |||
| self.streaming = streaming | |||
| self.conversation = conversation | |||
| self.is_new_conversation = False | |||
| self.model_instance = model_instance | |||
| self.message = None | |||
| self.retriever_resource = None | |||
| self.auto_generate_name = auto_generate_name | |||
| self.model_dict = self.app_model_config.model_dict | |||
| self.provider_name = self.model_dict.get('provider') | |||
| self.model_name = self.model_dict.get('name') | |||
| self.mode = app.mode | |||
| self.init() | |||
| self._pub_handler = PubHandler( | |||
| user=self.user, | |||
| task_id=self.task_id, | |||
| message=self.message, | |||
| conversation=self.conversation, | |||
| chain_pub=False, # disabled currently | |||
| agent_thought_pub=True | |||
| ) | |||
| def init(self): | |||
| override_model_configs = None | |||
| if self.is_override: | |||
| override_model_configs = self.app_model_config.to_dict() | |||
| introduction = '' | |||
| system_instruction = '' | |||
| system_instruction_tokens = 0 | |||
| if self.mode == 'chat': | |||
| introduction = self.app_model_config.opening_statement | |||
| if introduction: | |||
| prompt_template = PromptTemplateParser(template=introduction) | |||
| prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs} | |||
| try: | |||
| introduction = prompt_template.format(prompt_inputs) | |||
| except KeyError: | |||
| pass | |||
| if self.app_model_config.pre_prompt: | |||
| system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) | |||
| system_instruction = system_message.content | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=self.tenant_id, | |||
| model_provider_name=self.provider_name, | |||
| model_name=self.model_name | |||
| ) | |||
| system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message])) | |||
| if not self.conversation: | |||
| self.is_new_conversation = True | |||
| self.conversation = Conversation( | |||
| app_id=self.app.id, | |||
| app_model_config_id=self.app_model_config.id, | |||
| model_provider=self.provider_name, | |||
| model_id=self.model_name, | |||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | |||
| mode=self.mode, | |||
| name='New conversation', | |||
| inputs=self.inputs, | |||
| introduction=introduction, | |||
| system_instruction=system_instruction, | |||
| system_instruction_tokens=system_instruction_tokens, | |||
| status='normal', | |||
| from_source=('console' if isinstance(self.user, Account) else 'api'), | |||
| from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), | |||
| from_account_id=(self.user.id if isinstance(self.user, Account) else None), | |||
| ) | |||
| db.session.add(self.conversation) | |||
| db.session.commit() | |||
| self.message = Message( | |||
| app_id=self.app.id, | |||
| model_provider=self.provider_name, | |||
| model_id=self.model_name, | |||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | |||
| conversation_id=self.conversation.id, | |||
| inputs=self.inputs, | |||
| query=self.query, | |||
| message="", | |||
| message_tokens=0, | |||
| message_unit_price=0, | |||
| message_price_unit=0, | |||
| answer="", | |||
| answer_tokens=0, | |||
| answer_unit_price=0, | |||
| answer_price_unit=0, | |||
| provider_response_latency=0, | |||
| total_price=0, | |||
| currency=self.model_instance.get_currency(), | |||
| from_source=('console' if isinstance(self.user, Account) else 'api'), | |||
| from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), | |||
| from_account_id=(self.user.id if isinstance(self.user, Account) else None), | |||
| agent_based=self.app_model_config.agent_mode_dict.get('enabled'), | |||
| ) | |||
| db.session.add(self.message) | |||
| db.session.commit() | |||
| for file in self.files: | |||
| message_file = MessageFile( | |||
| message_id=self.message.id, | |||
| type=file.type.value, | |||
| transfer_method=file.transfer_method.value, | |||
| url=file.url, | |||
| upload_file_id=file.upload_file_id, | |||
| created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), | |||
| created_by=self.user.id | |||
| ) | |||
| db.session.add(message_file) | |||
| db.session.commit() | |||
| def append_message_text(self, text: str): | |||
| if text is not None: | |||
| self._pub_handler.pub_text(text) | |||
| def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): | |||
| message_tokens = llm_message.prompt_tokens | |||
| answer_tokens = llm_message.completion_tokens | |||
| message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER) | |||
| message_price_unit = self.model_instance.get_price_unit(MessageType.USER) | |||
| answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT) | |||
| message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER) | |||
| answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) | |||
| total_price = message_total_price + answer_total_price | |||
| self.message.message = llm_message.prompt | |||
| self.message.message_tokens = message_tokens | |||
| self.message.message_unit_price = message_unit_price | |||
| self.message.message_price_unit = message_price_unit | |||
| self.message.answer = PromptTemplateParser.remove_template_variables( | |||
| llm_message.completion.strip()) if llm_message.completion else '' | |||
| self.message.answer_tokens = answer_tokens | |||
| self.message.answer_unit_price = answer_unit_price | |||
| self.message.answer_price_unit = answer_price_unit | |||
| self.message.provider_response_latency = time.perf_counter() - self.start_at | |||
| self.message.total_price = total_price | |||
| db.session.commit() | |||
| message_was_created.send( | |||
| self.message, | |||
| conversation=self.conversation, | |||
| is_first_message=self.is_new_conversation, | |||
| auto_generate_name=self.auto_generate_name | |||
| ) | |||
| if not by_stopped: | |||
| self.end() | |||
| def init_chain(self, chain_result: ChainResult): | |||
| message_chain = MessageChain( | |||
| message_id=self.message.id, | |||
| type=chain_result.type, | |||
| input=json.dumps(chain_result.prompt), | |||
| output='' | |||
| ) | |||
| db.session.add(message_chain) | |||
| db.session.commit() | |||
| return message_chain | |||
| def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult): | |||
| message_chain.output = json.dumps(chain_result.completion) | |||
| db.session.commit() | |||
| self._pub_handler.pub_chain(message_chain) | |||
| def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought: | |||
| message_agent_thought = MessageAgentThought( | |||
| message_id=self.message.id, | |||
| message_chain_id=message_chain.id, | |||
| position=agent_loop.position, | |||
| thought=agent_loop.thought, | |||
| tool=agent_loop.tool_name, | |||
| tool_input=agent_loop.tool_input, | |||
| message=agent_loop.prompt, | |||
| message_price_unit=0, | |||
| answer=agent_loop.completion, | |||
| answer_price_unit=0, | |||
| created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), | |||
| created_by=self.user.id | |||
| ) | |||
| db.session.add(message_agent_thought) | |||
| db.session.commit() | |||
| self._pub_handler.pub_agent_thought(message_agent_thought) | |||
| return message_agent_thought | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM, | |||
| agent_loop: AgentLoop): | |||
| agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER) | |||
| agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER) | |||
| agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT) | |||
| loop_message_tokens = agent_loop.prompt_tokens | |||
| loop_answer_tokens = agent_loop.completion_tokens | |||
| loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER) | |||
| loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) | |||
| loop_total_price = loop_message_total_price + loop_answer_total_price | |||
| message_agent_thought.observation = agent_loop.tool_output | |||
| message_agent_thought.tool_process_data = ''  # currently not supported | |||
| message_agent_thought.message_token = loop_message_tokens | |||
| message_agent_thought.message_unit_price = agent_message_unit_price | |||
| message_agent_thought.message_price_unit = agent_message_price_unit | |||
| message_agent_thought.answer_token = loop_answer_tokens | |||
| message_agent_thought.answer_unit_price = agent_answer_unit_price | |||
| message_agent_thought.answer_price_unit = agent_answer_price_unit | |||
| message_agent_thought.latency = agent_loop.latency | |||
| message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | |||
| message_agent_thought.total_price = loop_total_price | |||
| message_agent_thought.currency = agent_model_instance.get_currency() | |||
| db.session.commit() | |||
| def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): | |||
| dataset_query = DatasetQuery( | |||
| dataset_id=dataset_query_obj.dataset_id, | |||
| content=dataset_query_obj.query, | |||
| source='app', | |||
| source_app_id=self.app.id, | |||
| created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), | |||
| created_by=self.user.id | |||
| ) | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| def on_dataset_query_finish(self, resource: List): | |||
| if resource and len(resource) > 0: | |||
| for item in resource: | |||
| dataset_retriever_resource = DatasetRetrieverResource( | |||
| message_id=self.message.id, | |||
| position=item.get('position'), | |||
| dataset_id=item.get('dataset_id'), | |||
| dataset_name=item.get('dataset_name'), | |||
| document_id=item.get('document_id'), | |||
| document_name=item.get('document_name'), | |||
| data_source_type=item.get('data_source_type'), | |||
| segment_id=item.get('segment_id'), | |||
| score=item.get('score') if 'score' in item else None, | |||
| hit_count=item.get('hit_count') if 'hit_count' in item else None, | |||
| word_count=item.get('word_count') if 'word_count' in item else None, | |||
| segment_position=item.get('segment_position') if 'segment_position' in item else None, | |||
| index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None, | |||
| content=item.get('content'), | |||
| retriever_from=item.get('retriever_from'), | |||
| created_by=self.user.id | |||
| ) | |||
| db.session.add(dataset_retriever_resource) | |||
| db.session.commit() | |||
| self.retriever_resource = resource | |||
| def on_message_replace(self, text: str): | |||
| if text is not None: | |||
| self._pub_handler.pub_message_replace(text) | |||
| def message_end(self): | |||
| self._pub_handler.pub_message_end(self.retriever_resource) | |||
| def end(self): | |||
| self._pub_handler.pub_message_end(self.retriever_resource) | |||
| self._pub_handler.pub_end() | |||
| def annotation_end(self, text: str, annotation_id: str, annotation_author_name: str): | |||
| self._pub_handler.pub_annotation(text, annotation_id, annotation_author_name, self.start_at) | |||
| self._pub_handler.pub_end() | |||
| class PubHandler: | |||
| def __init__(self, user: Union[Account, EndUser], task_id: str, | |||
| message: Message, conversation: Conversation, | |||
| chain_pub: bool = False, agent_thought_pub: bool = False): | |||
| self._channel = PubHandler.generate_channel_name(user, task_id) | |||
| self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id) | |||
| self._task_id = task_id | |||
| self._message = message | |||
| self._conversation = conversation | |||
| self._chain_pub = chain_pub | |||
| self._agent_thought_pub = agent_thought_pub | |||
| @classmethod | |||
| def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str): | |||
| if not user: | |||
| raise ValueError("user is required") | |||
| user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) | |||
| return "generate_result:{}-{}".format(user_str, task_id) | |||
| @classmethod | |||
| def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str): | |||
| user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id) | |||
| return "generate_result_stopped:{}-{}".format(user_str, task_id) | |||
| def pub_text(self, text: str): | |||
| content = { | |||
| 'event': 'message', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': str(self._message.id), | |||
| 'text': text, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': str(self._conversation.id) | |||
| } | |||
| } | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_message_replace(self, text: str): | |||
| content = { | |||
| 'event': 'message_replace', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': str(self._message.id), | |||
| 'text': text, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': str(self._conversation.id) | |||
| } | |||
| } | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_chain(self, message_chain: MessageChain): | |||
| if self._chain_pub: | |||
| content = { | |||
| 'event': 'chain', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'chain_id': message_chain.id, | |||
| 'type': message_chain.type, | |||
| 'input': json.loads(message_chain.input), | |||
| 'output': json.loads(message_chain.output), | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id | |||
| } | |||
| } | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_agent_thought(self, message_agent_thought: MessageAgentThought): | |||
| if self._agent_thought_pub: | |||
| content = { | |||
| 'event': 'agent_thought', | |||
| 'data': { | |||
| 'id': message_agent_thought.id, | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'chain_id': message_agent_thought.message_chain_id, | |||
| 'position': message_agent_thought.position, | |||
| 'thought': message_agent_thought.thought, | |||
| 'tool': message_agent_thought.tool, | |||
| 'tool_input': message_agent_thought.tool_input, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id | |||
| } | |||
| } | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_message_end(self, retriever_resource: List): | |||
| content = { | |||
| 'event': 'message_end', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id, | |||
| } | |||
| } | |||
| if retriever_resource: | |||
| content['data']['retriever_resources'] = retriever_resource | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_annotation(self, text: str, annotation_id: str, annotation_author_name: str, start_at: float): | |||
| content = { | |||
| 'event': 'annotation', | |||
| 'data': { | |||
| 'task_id': self._task_id, | |||
| 'message_id': self._message.id, | |||
| 'mode': self._conversation.mode, | |||
| 'conversation_id': self._conversation.id, | |||
| 'text': text, | |||
| 'annotation_id': annotation_id, | |||
| 'annotation_author_name': annotation_author_name | |||
| } | |||
| } | |||
| self._message.answer = text | |||
| self._message.provider_response_latency = time.perf_counter() - start_at | |||
| db.session.commit() | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| if self._is_stopped(): | |||
| self.pub_end() | |||
| raise ConversationTaskStoppedException() | |||
| def pub_end(self): | |||
| content = { | |||
| 'event': 'end', | |||
| } | |||
| redis_client.publish(self._channel, json.dumps(content)) | |||
| @classmethod | |||
| def pub_error(cls, user: Union[Account, EndUser], task_id: str, e): | |||
| content = { | |||
| 'error': type(e).__name__, | |||
| 'description': e.description if getattr(e, 'description', None) is not None else str(e) | |||
| } | |||
| channel = cls.generate_channel_name(user, task_id) | |||
| redis_client.publish(channel, json.dumps(content)) | |||
| def _is_stopped(self): | |||
| return redis_client.get(self._stopped_cache_key) is not None | |||
| @classmethod | |||
| def ping(cls, user: Union[Account, EndUser], task_id: str): | |||
| content = { | |||
| 'event': 'ping' | |||
| } | |||
| channel = cls.generate_channel_name(user, task_id) | |||
| redis_client.publish(channel, json.dumps(content)) | |||
| @classmethod | |||
| def stop(cls, user: Union[Account, EndUser], task_id: str): | |||
| stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) | |||
| redis_client.setex(stopped_cache_key, 600, 1) | |||
| class ConversationTaskStoppedException(Exception): | |||
| pass | |||
| class ConversationTaskInterruptException(Exception): | |||
| pass | |||
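The stop flow in PubHandler is coordinated entirely through Redis: stop() writes a short-lived key, and every pub_* call checks _is_stopped() before continuing, publishing the 'end' event and raising ConversationTaskStoppedException when the key is present. A rough caller-side sketch under those assumptions (the wrapper function name is hypothetical):

def stop_generation(user, task_id: str) -> None:
    # Mark the task as stopped for 600 seconds; the worker's next pub_* call
    # will notice the key, publish the 'end' event and raise
    # ConversationTaskStoppedException, unwinding the generation loop.
    PubHandler.stop(user, task_id)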
| @@ -1,9 +1,11 @@ | |||
| from typing import Any, Dict, Optional, Sequence | |||
| from typing import Any, Dict, Optional, Sequence, cast | |||
| from langchain.schema import Document | |||
| from sqlalchemy import func | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| @@ -69,10 +71,12 @@ class DatasetDocumentStore: | |||
| max_position = 0 | |||
| embedding_model = None | |||
| if self._dataset.indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=self._dataset.tenant_id, | |||
| model_provider_name=self._dataset.embedding_model_provider, | |||
| model_name=self._dataset.embedding_model | |||
| provider=self._dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=self._dataset.embedding_model | |||
| ) | |||
| for doc in docs: | |||
| @@ -89,7 +93,16 @@ class DatasetDocumentStore: | |||
| ) | |||
| # calc tokens used for embedding | |||
| tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0 | |||
| if embedding_model: | |||
| model_type_instance = embedding_model.model_type_instance | |||
| model_type_instance = cast(TextEmbeddingModel, model_type_instance) | |||
| tokens = model_type_instance.get_num_tokens( | |||
| model=embedding_model.model, | |||
| credentials=embedding_model.credentials, | |||
| texts=[doc.page_content] | |||
| ) | |||
| else: | |||
| tokens = 0 | |||
| if not segment_document: | |||
| max_position += 1 | |||
| @@ -1,19 +1,22 @@ | |||
| import logging | |||
| from typing import List | |||
| from typing import List, Optional | |||
| import numpy as np | |||
| from langchain.embeddings.base import Embeddings | |||
| from sqlalchemy.exc import IntegrityError | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_manager import ModelInstance | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models.dataset import Embedding | |||
| logger = logging.getLogger(__name__) | |||
| class CacheEmbedding(Embeddings): | |||
| def __init__(self, embeddings: BaseEmbedding): | |||
| self._embeddings = embeddings | |||
| def __init__(self, model_instance: ModelInstance, user: Optional[str] = None) -> None: | |||
| self._model_instance = model_instance | |||
| self._user = user | |||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||
| """Embed search docs.""" | |||
| @@ -22,7 +25,7 @@ class CacheEmbedding(Embeddings): | |||
| embedding_queue_indices = [] | |||
| for i, text in enumerate(texts): | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() | |||
| if embedding: | |||
| text_embeddings[i] = embedding.get_embedding() | |||
| else: | |||
| @@ -30,15 +33,21 @@ class CacheEmbedding(Embeddings): | |||
| if embedding_queue_indices: | |||
| try: | |||
| embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices]) | |||
| embedding_result = self._model_instance.invoke_text_embedding( | |||
| texts=[texts[i] for i in embedding_queue_indices], | |||
| user=self._user | |||
| ) | |||
| embedding_results = embedding_result.embeddings | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| logger.error('Failed to embed documents: %s', ex) | |||
| raise ex | |||
| for i, indice in enumerate(embedding_queue_indices): | |||
| hash = helper.generate_text_hash(texts[indice]) | |||
| try: | |||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||
| embedding = Embedding(model_name=self._model_instance.model, hash=hash) | |||
| vector = embedding_results[i] | |||
| normalized_embedding = (vector / np.linalg.norm(vector)).tolist() | |||
| text_embeddings[indice] = normalized_embedding | |||
| @@ -58,18 +67,23 @@ class CacheEmbedding(Embeddings): | |||
| """Embed query text.""" | |||
| # use doc embedding cache or store if not exists | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() | |||
| if embedding: | |||
| return embedding.get_embedding() | |||
| try: | |||
| embedding_results = self._embeddings.client.embed_query(text) | |||
| embedding_result = self._model_instance.invoke_text_embedding( | |||
| texts=[text], | |||
| user=self._user | |||
| ) | |||
| embedding_results = embedding_result.embeddings[0] | |||
| embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| raise ex | |||
| try: | |||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||
| embedding = Embedding(model_name=self._model_instance.model, hash=hash) | |||
| embedding.set_embedding(embedding_results) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| @@ -79,4 +93,3 @@ class CacheEmbedding(Embeddings): | |||
| logging.exception('Failed to add embedding to db') | |||
| return embedding_results | |||
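A minimal usage sketch of the refactored CacheEmbedding, assuming a text-embedding ModelInstance obtained from ModelManager as elsewhere in this change; tenant, provider and model values are placeholders:

model_manager = ModelManager()
embedding_model_instance = model_manager.get_model_instance(
    tenant_id='tenant-id',                  # placeholder
    provider='openai',                      # placeholder
    model_type=ModelType.TEXT_EMBEDDING,
    model='text-embedding-ada-002'          # placeholder
)

embeddings = CacheEmbedding(embedding_model_instance, user='end-user-id')
vectors = embeddings.embed_documents(['hello world', 'how are you'])  # results cached in the Embedding table
query_vector = embeddings.embed_query('hello world')                  # typically served from the cache once stored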
| @@ -0,0 +1,265 @@ | |||
| from enum import Enum | |||
| from typing import Optional, Any, cast | |||
| from pydantic import BaseModel | |||
| from core.entities.provider_configuration import ProviderModelBundle | |||
| from core.file.file_obj import FileObj | |||
| from core.model_runtime.entities.message_entities import PromptMessageRole | |||
| from core.model_runtime.entities.model_entities import AIModelEntity | |||
| class ModelConfigEntity(BaseModel): | |||
| """ | |||
| Model Config Entity. | |||
| """ | |||
| provider: str | |||
| model: str | |||
| model_schema: AIModelEntity | |||
| mode: str | |||
| provider_model_bundle: ProviderModelBundle | |||
| credentials: dict[str, Any] = {} | |||
| parameters: dict[str, Any] = {} | |||
| stop: list[str] = [] | |||
| class AdvancedChatMessageEntity(BaseModel): | |||
| """ | |||
| Advanced Chat Message Entity. | |||
| """ | |||
| text: str | |||
| role: PromptMessageRole | |||
| class AdvancedChatPromptTemplateEntity(BaseModel): | |||
| """ | |||
| Advanced Chat Prompt Template Entity. | |||
| """ | |||
| messages: list[AdvancedChatMessageEntity] | |||
| class AdvancedCompletionPromptTemplateEntity(BaseModel): | |||
| """ | |||
| Advanced Completion Prompt Template Entity. | |||
| """ | |||
| class RolePrefixEntity(BaseModel): | |||
| """ | |||
| Role Prefix Entity. | |||
| """ | |||
| user: str | |||
| assistant: str | |||
| prompt: str | |||
| role_prefix: Optional[RolePrefixEntity] = None | |||
| class PromptTemplateEntity(BaseModel): | |||
| """ | |||
| Prompt Template Entity. | |||
| """ | |||
| class PromptType(Enum): | |||
| """ | |||
| Prompt Type. | |||
| 'simple', 'advanced' | |||
| """ | |||
| SIMPLE = 'simple' | |||
| ADVANCED = 'advanced' | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'PromptType': | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid prompt type value {value}') | |||
| prompt_type: PromptType | |||
| simple_prompt_template: Optional[str] = None | |||
| advanced_chat_prompt_template: Optional[AdvancedChatPromptTemplateEntity] = None | |||
| advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None | |||
| class ExternalDataVariableEntity(BaseModel): | |||
| """ | |||
| External Data Variable Entity. | |||
| """ | |||
| variable: str | |||
| type: str | |||
| config: dict[str, Any] = {} | |||
| class DatasetRetrieveConfigEntity(BaseModel): | |||
| """ | |||
| Dataset Retrieve Config Entity. | |||
| """ | |||
| class RetrieveStrategy(Enum): | |||
| """ | |||
| Dataset Retrieve Strategy. | |||
| 'single' or 'multiple' | |||
| """ | |||
| SINGLE = 'single' | |||
| MULTIPLE = 'multiple' | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'RetrieveStrategy': | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid retrieve strategy value {value}') | |||
| query_variable: Optional[str] = None # Only when app mode is completion | |||
| retrieve_strategy: RetrieveStrategy | |||
| single_strategy: Optional[str] = None # for temp | |||
| top_k: Optional[int] = None | |||
| score_threshold: Optional[float] = None | |||
| reranking_model: Optional[dict] = None | |||
| class DatasetEntity(BaseModel): | |||
| """ | |||
| Dataset Config Entity. | |||
| """ | |||
| dataset_ids: list[str] | |||
| retrieve_config: DatasetRetrieveConfigEntity | |||
| class SensitiveWordAvoidanceEntity(BaseModel): | |||
| """ | |||
| Sensitive Word Avoidance Entity. | |||
| """ | |||
| type: str | |||
| config: dict[str, Any] = {} | |||
| class FileUploadEntity(BaseModel): | |||
| """ | |||
| File Upload Entity. | |||
| """ | |||
| image_config: Optional[dict[str, Any]] = None | |||
| class AgentToolEntity(BaseModel): | |||
| """ | |||
| Agent Tool Entity. | |||
| """ | |||
| tool_id: str | |||
| config: dict[str, Any] = {} | |||
| class AgentEntity(BaseModel): | |||
| """ | |||
| Agent Entity. | |||
| """ | |||
| class Strategy(Enum): | |||
| """ | |||
| Agent Strategy. | |||
| """ | |||
| CHAIN_OF_THOUGHT = 'chain-of-thought' | |||
| FUNCTION_CALLING = 'function-calling' | |||
| provider: str | |||
| model: str | |||
| strategy: Strategy | |||
| tools: list[AgentToolEntity] = [] | |||
| class AppOrchestrationConfigEntity(BaseModel): | |||
| """ | |||
| App Orchestration Config Entity. | |||
| """ | |||
| model_config: ModelConfigEntity | |||
| prompt_template: PromptTemplateEntity | |||
| external_data_variables: list[ExternalDataVariableEntity] = [] | |||
| agent: Optional[AgentEntity] = None | |||
| # features | |||
| dataset: Optional[DatasetEntity] = None | |||
| file_upload: Optional[FileUploadEntity] = None | |||
| opening_statement: Optional[str] = None | |||
| suggested_questions_after_answer: bool = False | |||
| show_retrieve_source: bool = False | |||
| more_like_this: bool = False | |||
| speech_to_text: bool = False | |||
| sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None | |||
| class InvokeFrom(Enum): | |||
| """ | |||
| Invoke From. | |||
| """ | |||
| SERVICE_API = 'service-api' | |||
| WEB_APP = 'web-app' | |||
| EXPLORE = 'explore' | |||
| DEBUGGER = 'debugger' | |||
| @classmethod | |||
| def value_of(cls, value: str) -> 'InvokeFrom': | |||
| """ | |||
| Get value of given mode. | |||
| :param value: mode value | |||
| :return: mode | |||
| """ | |||
| for mode in cls: | |||
| if mode.value == value: | |||
| return mode | |||
| raise ValueError(f'invalid invoke from value {value}') | |||
| def to_source(self) -> str: | |||
| """ | |||
| Get source of invoke from. | |||
| :return: source | |||
| """ | |||
| if self == InvokeFrom.WEB_APP: | |||
| return 'web_app' | |||
| elif self == InvokeFrom.DEBUGGER: | |||
| return 'dev' | |||
| elif self == InvokeFrom.EXPLORE: | |||
| return 'explore_app' | |||
| elif self == InvokeFrom.SERVICE_API: | |||
| return 'api' | |||
| return 'dev' | |||
| class ApplicationGenerateEntity(BaseModel): | |||
| """ | |||
| Application Generate Entity. | |||
| """ | |||
| task_id: str | |||
| tenant_id: str | |||
| app_id: str | |||
| app_model_config_id: str | |||
| # for save | |||
| app_model_config_dict: dict | |||
| app_model_config_override: bool | |||
| # Converted from app_model_config to an entity object, or overridden directly by external input | |||
| app_orchestration_config_entity: AppOrchestrationConfigEntity | |||
| conversation_id: Optional[str] = None | |||
| inputs: dict[str, str] | |||
| query: Optional[str] = None | |||
| files: list[FileObj] = [] | |||
| user_id: str | |||
| # extras | |||
| stream: bool | |||
| invoke_from: InvokeFrom | |||
| # extra parameters, like: auto_generate_conversation_name | |||
| extras: dict[str, Any] = {} | |||
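A small illustration of the value_of and to_source helpers defined in this file, using values from the enums above:

invoke_from = InvokeFrom.value_of('web-app')   # InvokeFrom.WEB_APP
source = invoke_from.to_source()               # 'web_app'

prompt_type = PromptTemplateEntity.PromptType.value_of('advanced')  # PromptType.ADVANCED
# PromptTemplateEntity.PromptType.value_of('unknown') raises ValueError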
| @@ -0,0 +1,128 @@ | |||
| import enum | |||
| from typing import Any, cast | |||
| from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage, TextPromptMessageContent, \ | |||
| ImagePromptMessageContent, AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage | |||
| class PromptMessageFileType(enum.Enum): | |||
| IMAGE = 'image' | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in PromptMessageFileType: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class PromptMessageFile(BaseModel): | |||
| type: PromptMessageFileType | |||
| data: Any | |||
| class ImagePromptMessageFile(PromptMessageFile): | |||
| class DETAIL(enum.Enum): | |||
| LOW = 'low' | |||
| HIGH = 'high' | |||
| type: PromptMessageFileType = PromptMessageFileType.IMAGE | |||
| detail: DETAIL = DETAIL.LOW | |||
| class LCHumanMessageWithFiles(HumanMessage): | |||
| # content: Union[str, List[Union[str, Dict]]] | |||
| content: str | |||
| files: list[PromptMessageFile] | |||
| def lc_messages_to_prompt_messages(messages: list[BaseMessage]) -> list[PromptMessage]: | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| if isinstance(message, HumanMessage): | |||
| if isinstance(message, LCHumanMessageWithFiles): | |||
| file_prompt_message_contents = [] | |||
| for file in message.files: | |||
| if file.type == PromptMessageFileType.IMAGE: | |||
| file = cast(ImagePromptMessageFile, file) | |||
| file_prompt_message_contents.append(ImagePromptMessageContent( | |||
| data=file.data, | |||
| detail=ImagePromptMessageContent.DETAIL.HIGH | |||
| if file.detail.value == "high" else ImagePromptMessageContent.DETAIL.LOW | |||
| )) | |||
| prompt_message_contents = [TextPromptMessageContent(data=message.content)] | |||
| prompt_message_contents.extend(file_prompt_message_contents) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=message.content)) | |||
| elif isinstance(message, AIMessage): | |||
| message_kwargs = { | |||
| 'content': message.content | |||
| } | |||
| if 'function_call' in message.additional_kwargs: | |||
| message_kwargs['tool_calls'] = [ | |||
| AssistantPromptMessage.ToolCall( | |||
| id=message.additional_kwargs['function_call']['id'], | |||
| type='function', | |||
| function=AssistantPromptMessage.ToolCall.ToolCallFunction( | |||
| name=message.additional_kwargs['function_call']['name'], | |||
| arguments=message.additional_kwargs['function_call']['arguments'] | |||
| ) | |||
| ) | |||
| ] | |||
| prompt_messages.append(AssistantPromptMessage(**message_kwargs)) | |||
| elif isinstance(message, SystemMessage): | |||
| prompt_messages.append(SystemPromptMessage(content=message.content)) | |||
| elif isinstance(message, FunctionMessage): | |||
| prompt_messages.append(ToolPromptMessage(content=message.content, tool_call_id=message.name)) | |||
| return prompt_messages | |||
| def prompt_messages_to_lc_messages(prompt_messages: list[PromptMessage]) -> list[BaseMessage]: | |||
| messages = [] | |||
| for prompt_message in prompt_messages: | |||
| if isinstance(prompt_message, UserPromptMessage): | |||
| if isinstance(prompt_message.content, str): | |||
| messages.append(HumanMessage(content=prompt_message.content)) | |||
| else: | |||
| message_contents = [] | |||
| for content in prompt_message.content: | |||
| if isinstance(content, TextPromptMessageContent): | |||
| message_contents.append(content.data) | |||
| elif isinstance(content, ImagePromptMessageContent): | |||
| message_contents.append({ | |||
| 'type': 'image', | |||
| 'data': content.data, | |||
| 'detail': content.detail.value | |||
| }) | |||
| messages.append(HumanMessage(content=message_contents)) | |||
| elif isinstance(prompt_message, AssistantPromptMessage): | |||
| message_kwargs = { | |||
| 'content': prompt_message.content | |||
| } | |||
| if prompt_message.tool_calls: | |||
| message_kwargs['additional_kwargs'] = { | |||
| 'function_call': { | |||
| 'id': prompt_message.tool_calls[0].id, | |||
| 'name': prompt_message.tool_calls[0].function.name, | |||
| 'arguments': prompt_message.tool_calls[0].function.arguments | |||
| } | |||
| } | |||
| messages.append(AIMessage(**message_kwargs)) | |||
| elif isinstance(prompt_message, SystemPromptMessage): | |||
| messages.append(SystemMessage(content=prompt_message.content)) | |||
| elif isinstance(prompt_message, ToolPromptMessage): | |||
| messages.append(FunctionMessage(name=prompt_message.tool_call_id, content=prompt_message.content)) | |||
| return messages | |||
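A minimal round-trip sketch of the two converters above; the message contents are placeholders:

lc_messages = [
    SystemMessage(content='You are a helpful assistant.'),
    HumanMessage(content='Hello!'),
    AIMessage(content='Hi, how can I help?')
]

prompt_messages = lc_messages_to_prompt_messages(lc_messages)
# -> [SystemPromptMessage, UserPromptMessage, AssistantPromptMessage]

restored = prompt_messages_to_lc_messages(prompt_messages)
# -> SystemMessage / HumanMessage / AIMessage with the same content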
| @@ -0,0 +1,71 @@ | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.common_entities import I18nObject | |||
| from core.model_runtime.entities.model_entities import ProviderModel, ModelType | |||
| from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderEntity | |||
| class ModelStatus(Enum): | |||
| """ | |||
| Enum class for model status. | |||
| """ | |||
| ACTIVE = "active" | |||
| NO_CONFIGURE = "no-configure" | |||
| QUOTA_EXCEEDED = "quota-exceeded" | |||
| NO_PERMISSION = "no-permission" | |||
| class SimpleModelProviderEntity(BaseModel): | |||
| """ | |||
| Simple provider. | |||
| """ | |||
| provider: str | |||
| label: I18nObject | |||
| icon_small: Optional[I18nObject] = None | |||
| icon_large: Optional[I18nObject] = None | |||
| supported_model_types: list[ModelType] | |||
| def __init__(self, provider_entity: ProviderEntity) -> None: | |||
| """ | |||
| Init simple provider. | |||
| :param provider_entity: provider entity | |||
| """ | |||
| super().__init__( | |||
| provider=provider_entity.provider, | |||
| label=provider_entity.label, | |||
| icon_small=provider_entity.icon_small, | |||
| icon_large=provider_entity.icon_large, | |||
| supported_model_types=provider_entity.supported_model_types | |||
| ) | |||
| class ModelWithProviderEntity(ProviderModel): | |||
| """ | |||
| Model with provider entity. | |||
| """ | |||
| provider: SimpleModelProviderEntity | |||
| status: ModelStatus | |||
| class DefaultModelProviderEntity(BaseModel): | |||
| """ | |||
| Default model provider entity. | |||
| """ | |||
| provider: str | |||
| label: I18nObject | |||
| icon_small: Optional[I18nObject] = None | |||
| icon_large: Optional[I18nObject] = None | |||
| supported_model_types: list[ModelType] | |||
| class DefaultModelEntity(BaseModel): | |||
| """ | |||
| Default model entity. | |||
| """ | |||
| model: str | |||
| model_type: ModelType | |||
| provider: DefaultModelProviderEntity | |||
| @@ -0,0 +1,657 @@ | |||
| import datetime | |||
| import json | |||
| import time | |||
| from json import JSONDecodeError | |||
| from typing import Optional, List, Dict, Tuple, Iterator | |||
| from pydantic import BaseModel | |||
| from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity | |||
| from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus | |||
| from core.helper import encrypter | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| from core.model_runtime.model_providers.__base.model_provider import ModelProvider | |||
| from core.model_runtime.utils import encoders | |||
| from extensions.ext_database import db | |||
| from models.provider import ProviderType, Provider, ProviderModel, TenantPreferredModelProvider | |||
| class ProviderConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider configuration. | |||
| """ | |||
| tenant_id: str | |||
| provider: ProviderEntity | |||
| preferred_provider_type: ProviderType | |||
| using_provider_type: ProviderType | |||
| system_configuration: SystemConfiguration | |||
| custom_configuration: CustomConfiguration | |||
| def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: | |||
| """ | |||
| Get current credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| if self.using_provider_type == ProviderType.SYSTEM: | |||
| return self.system_configuration.credentials | |||
| else: | |||
| if self.custom_configuration.models: | |||
| for model_configuration in self.custom_configuration.models: | |||
| if model_configuration.model_type == model_type and model_configuration.model == model: | |||
| return model_configuration.credentials | |||
| if self.custom_configuration.provider: | |||
| return self.custom_configuration.provider.credentials | |||
| else: | |||
| return None | |||
| def get_system_configuration_status(self) -> SystemConfigurationStatus: | |||
| """ | |||
| Get system configuration status. | |||
| :return: | |||
| """ | |||
| if self.system_configuration.enabled is False: | |||
| return SystemConfigurationStatus.UNSUPPORTED | |||
| current_quota_type = self.system_configuration.current_quota_type | |||
| current_quota_configuration = next( | |||
| (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), | |||
| None | |||
| ) | |||
| if current_quota_configuration is None: | |||
| return SystemConfigurationStatus.UNSUPPORTED | |||
| return SystemConfigurationStatus.ACTIVE if current_quota_configuration.is_valid else \ | |||
| SystemConfigurationStatus.QUOTA_EXCEEDED | |||
| def is_custom_configuration_available(self) -> bool: | |||
| """ | |||
| Check custom configuration available. | |||
| :return: | |||
| """ | |||
| return (self.custom_configuration.provider is not None | |||
| or len(self.custom_configuration.models) > 0) | |||
| def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: | |||
| """ | |||
| Get custom credentials. | |||
| :param obfuscated: obfuscated secret data in credentials | |||
| :return: | |||
| """ | |||
| if self.custom_configuration.provider is None: | |||
| return None | |||
| credentials = self.custom_configuration.provider.credentials | |||
| if not obfuscated: | |||
| return credentials | |||
| # Obfuscate credentials | |||
| return self._obfuscated_credentials( | |||
| credentials=credentials, | |||
| credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas | |||
| if self.provider.provider_credential_schema else [] | |||
| ) | |||
| def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]: | |||
| """ | |||
| Validate custom credentials. | |||
| :param credentials: provider credentials | |||
| :return: | |||
| """ | |||
| # get provider | |||
| provider_record = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == self.tenant_id, | |||
| Provider.provider_name == self.provider.provider, | |||
| Provider.provider_type == ProviderType.CUSTOM.value | |||
| ).first() | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self._extract_secret_variables( | |||
| self.provider.provider_credential_schema.credential_form_schemas | |||
| if self.provider.provider_credential_schema else [] | |||
| ) | |||
| if provider_record: | |||
| try: | |||
| original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {} | |||
| except JSONDecodeError: | |||
| original_credentials = {} | |||
| # decrypt credentials | |||
| for key, value in credentials.items(): | |||
| if key in provider_credential_secret_variables: | |||
| # if [__HIDDEN__] is sent for a secret input, restore the original value | |||
| if value == '[__HIDDEN__]' and key in original_credentials: | |||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | |||
| model_provider_factory.provider_credentials_validate( | |||
| self.provider.provider, | |||
| credentials | |||
| ) | |||
| for key, value in credentials.items(): | |||
| if key in provider_credential_secret_variables: | |||
| credentials[key] = encrypter.encrypt_token(self.tenant_id, value) | |||
| return provider_record, credentials | |||
| def add_or_update_custom_credentials(self, credentials: dict) -> None: | |||
| """ | |||
| Add or update custom provider credentials. | |||
| :param credentials: | |||
| :return: | |||
| """ | |||
| # validate custom provider config | |||
| provider_record, credentials = self.custom_credentials_validate(credentials) | |||
| # save provider | |||
| # Note: Do not switch the preferred provider, which allows users to use quotas first | |||
| if provider_record: | |||
| provider_record.encrypted_config = json.dumps(credentials) | |||
| provider_record.is_valid = True | |||
| provider_record.updated_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| else: | |||
| provider_record = Provider( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=json.dumps(credentials), | |||
| is_valid=True | |||
| ) | |||
| db.session.add(provider_record) | |||
| db.session.commit() | |||
| self.switch_preferred_provider_type(ProviderType.CUSTOM) | |||
| def delete_custom_credentials(self) -> None: | |||
| """ | |||
| Delete custom provider credentials. | |||
| :return: | |||
| """ | |||
| # get provider | |||
| provider_record = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == self.tenant_id, | |||
| Provider.provider_name == self.provider.provider, | |||
| Provider.provider_type == ProviderType.CUSTOM.value | |||
| ).first() | |||
| # delete provider | |||
| if provider_record: | |||
| self.switch_preferred_provider_type(ProviderType.SYSTEM) | |||
| db.session.delete(provider_record) | |||
| db.session.commit() | |||
| def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ | |||
| -> Optional[dict]: | |||
| """ | |||
| Get custom model credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param obfuscated: obfuscated secret data in credentials | |||
| :return: | |||
| """ | |||
| if not self.custom_configuration.models: | |||
| return None | |||
| for model_configuration in self.custom_configuration.models: | |||
| if model_configuration.model_type == model_type and model_configuration.model == model: | |||
| credentials = model_configuration.credentials | |||
| if not obfuscated: | |||
| return credentials | |||
| # Obfuscate credentials | |||
| return self._obfuscated_credentials( | |||
| credentials=credentials, | |||
| credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | |||
| if self.provider.model_credential_schema else [] | |||
| ) | |||
| return None | |||
| def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ | |||
| -> Tuple[ProviderModel, dict]: | |||
| """ | |||
| Validate custom model credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| # get provider model | |||
| provider_model_record = db.session.query(ProviderModel) \ | |||
| .filter( | |||
| ProviderModel.tenant_id == self.tenant_id, | |||
| ProviderModel.provider_name == self.provider.provider, | |||
| ProviderModel.model_name == model, | |||
| ProviderModel.model_type == model_type.to_origin_model_type() | |||
| ).first() | |||
| # Get provider credential secret variables | |||
| provider_credential_secret_variables = self._extract_secret_variables( | |||
| self.provider.model_credential_schema.credential_form_schemas | |||
| if self.provider.model_credential_schema else [] | |||
| ) | |||
| if provider_model_record: | |||
| try: | |||
| original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} | |||
| except JSONDecodeError: | |||
| original_credentials = {} | |||
| # decrypt credentials | |||
| for key, value in credentials.items(): | |||
| if key in provider_credential_secret_variables: | |||
| # if [__HIDDEN__] is sent for a secret input, restore the original value | |||
| if value == '[__HIDDEN__]' and key in original_credentials: | |||
| credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | |||
| model_provider_factory.model_credentials_validate( | |||
| provider=self.provider.provider, | |||
| model_type=model_type, | |||
| model=model, | |||
| credentials=credentials | |||
| ) | |||
| model_schema = ( | |||
| model_provider_factory.get_provider_instance(self.provider.provider) | |||
| .get_model_instance(model_type)._get_customizable_model_schema( | |||
| model=model, | |||
| credentials=credentials | |||
| ) | |||
| ) | |||
| if model_schema: | |||
| credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema)) | |||
| for key, value in credentials.items(): | |||
| if key in provider_credential_secret_variables: | |||
| credentials[key] = encrypter.encrypt_token(self.tenant_id, value) | |||
| return provider_model_record, credentials | |||
| def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None: | |||
| """ | |||
| Add or update custom model credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param credentials: model credentials | |||
| :return: | |||
| """ | |||
| # validate custom model config | |||
| provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials) | |||
| # save provider model | |||
| # Note: Do not switch the preferred provider, which allows users to use quotas first | |||
| if provider_model_record: | |||
| provider_model_record.encrypted_config = json.dumps(credentials) | |||
| provider_model_record.is_valid = True | |||
| provider_model_record.updated_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| else: | |||
| provider_model_record = ProviderModel( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| model_name=model, | |||
| model_type=model_type.to_origin_model_type(), | |||
| encrypted_config=json.dumps(credentials), | |||
| is_valid=True | |||
| ) | |||
| db.session.add(provider_model_record) | |||
| db.session.commit() | |||
| def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: | |||
| """ | |||
| Delete custom model credentials. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| # get provider model | |||
| provider_model_record = db.session.query(ProviderModel) \ | |||
| .filter( | |||
| ProviderModel.tenant_id == self.tenant_id, | |||
| ProviderModel.provider_name == self.provider.provider, | |||
| ProviderModel.model_name == model, | |||
| ProviderModel.model_type == model_type.to_origin_model_type() | |||
| ).first() | |||
| # delete provider model | |||
| if provider_model_record: | |||
| db.session.delete(provider_model_record) | |||
| db.session.commit() | |||
| def get_provider_instance(self) -> ModelProvider: | |||
| """ | |||
| Get provider instance. | |||
| :return: | |||
| """ | |||
| return model_provider_factory.get_provider_instance(self.provider.provider) | |||
| def get_model_type_instance(self, model_type: ModelType) -> AIModel: | |||
| """ | |||
| Get current model type instance. | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| # Get provider instance | |||
| provider_instance = self.get_provider_instance() | |||
| # Get model instance of LLM | |||
| return provider_instance.get_model_instance(model_type) | |||
| def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: | |||
| """ | |||
| Switch preferred provider type. | |||
| :param provider_type: | |||
| :return: | |||
| """ | |||
| if provider_type == self.preferred_provider_type: | |||
| return | |||
| if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: | |||
| return | |||
| # get preferred provider | |||
| preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ | |||
| .filter( | |||
| TenantPreferredModelProvider.tenant_id == self.tenant_id, | |||
| TenantPreferredModelProvider.provider_name == self.provider.provider | |||
| ).first() | |||
| if preferred_model_provider: | |||
| preferred_model_provider.preferred_provider_type = provider_type.value | |||
| else: | |||
| preferred_model_provider = TenantPreferredModelProvider( | |||
| tenant_id=self.tenant_id, | |||
| provider_name=self.provider.provider, | |||
| preferred_provider_type=provider_type.value | |||
| ) | |||
| db.session.add(preferred_model_provider) | |||
| db.session.commit() | |||
| def _extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: | |||
| """ | |||
| Extract secret input form variables. | |||
| :param credential_form_schemas: | |||
| :return: | |||
| """ | |||
| secret_input_form_variables = [] | |||
| for credential_form_schema in credential_form_schemas: | |||
| if credential_form_schema.type == FormType.SECRET_INPUT: | |||
| secret_input_form_variables.append(credential_form_schema.variable) | |||
| return secret_input_form_variables | |||
| def _obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: | |||
| """ | |||
| Obfuscate credentials. | |||
| :param credentials: credentials | |||
| :param credential_form_schemas: credential form schemas | |||
| :return: | |||
| """ | |||
| # Get provider credential secret variables | |||
| credential_secret_variables = self._extract_secret_variables( | |||
| credential_form_schemas | |||
| ) | |||
| # Obfuscate provider credentials | |||
| copy_credentials = credentials.copy() | |||
| for key, value in copy_credentials.items(): | |||
| if key in credential_secret_variables: | |||
| copy_credentials[key] = encrypter.obfuscated_token(value) | |||
| return copy_credentials | |||
| def get_provider_model(self, model_type: ModelType, | |||
| model: str, | |||
| only_active: bool = False) -> Optional[ModelWithProviderEntity]: | |||
| """ | |||
| Get provider model. | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :param only_active: return active model only | |||
| :return: | |||
| """ | |||
| provider_models = self.get_provider_models(model_type, only_active) | |||
| for provider_model in provider_models: | |||
| if provider_model.model == model: | |||
| return provider_model | |||
| return None | |||
| def get_provider_models(self, model_type: Optional[ModelType] = None, | |||
| only_active: bool = False) -> list[ModelWithProviderEntity]: | |||
| """ | |||
| Get provider models. | |||
| :param model_type: model type | |||
| :param only_active: only active models | |||
| :return: | |||
| """ | |||
| provider_instance = self.get_provider_instance() | |||
| model_types = [] | |||
| if model_type: | |||
| model_types.append(model_type) | |||
| else: | |||
| model_types = provider_instance.get_provider_schema().supported_model_types | |||
| if self.using_provider_type == ProviderType.SYSTEM: | |||
| provider_models = self._get_system_provider_models( | |||
| model_types=model_types, | |||
| provider_instance=provider_instance | |||
| ) | |||
| else: | |||
| provider_models = self._get_custom_provider_models( | |||
| model_types=model_types, | |||
| provider_instance=provider_instance | |||
| ) | |||
| if only_active: | |||
| provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] | |||
| # sort provider_models by model type | |||
| return sorted(provider_models, key=lambda x: x.model_type.value) | |||
| def _get_system_provider_models(self, | |||
| model_types: list[ModelType], | |||
| provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: | |||
| """ | |||
| Get system provider models. | |||
| :param model_types: model types | |||
| :param provider_instance: provider instance | |||
| :return: | |||
| """ | |||
| provider_models = [] | |||
| for model_type in model_types: | |||
| provider_models.extend( | |||
| [ | |||
| ModelWithProviderEntity( | |||
| **m.dict(), | |||
| provider=SimpleModelProviderEntity(self.provider), | |||
| status=ModelStatus.ACTIVE | |||
| ) | |||
| for m in provider_instance.models(model_type) | |||
| ] | |||
| ) | |||
| for quota_configuration in self.system_configuration.quota_configurations: | |||
| if self.system_configuration.current_quota_type != quota_configuration.quota_type: | |||
| continue | |||
| restrict_llms = quota_configuration.restrict_llms | |||
| if not restrict_llms: | |||
| break | |||
| # if the LLM is not in the restricted list, mark it as no-permission | |||
| for m in provider_models: | |||
| if m.model_type == ModelType.LLM and m.model not in restrict_llms: | |||
| m.status = ModelStatus.NO_PERMISSION | |||
| elif not quota_configuration.is_valid: | |||
| m.status = ModelStatus.QUOTA_EXCEEDED | |||
| return provider_models | |||
| def _get_custom_provider_models(self, | |||
| model_types: list[ModelType], | |||
| provider_instance: ModelProvider) -> list[ModelWithProviderEntity]: | |||
| """ | |||
| Get custom provider models. | |||
| :param model_types: model types | |||
| :param provider_instance: provider instance | |||
| :return: | |||
| """ | |||
| provider_models = [] | |||
| credentials = None | |||
| if self.custom_configuration.provider: | |||
| credentials = self.custom_configuration.provider.credentials | |||
| for model_type in model_types: | |||
| if model_type not in self.provider.supported_model_types: | |||
| continue | |||
| models = provider_instance.models(model_type) | |||
| for m in models: | |||
| provider_models.append( | |||
| ModelWithProviderEntity( | |||
| **m.dict(), | |||
| provider=SimpleModelProviderEntity(self.provider), | |||
| status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE | |||
| ) | |||
| ) | |||
| # custom models | |||
| for model_configuration in self.custom_configuration.models: | |||
| if model_configuration.model_type not in model_types: | |||
| continue | |||
| custom_model_schema = ( | |||
| provider_instance.get_model_instance(model_configuration.model_type) | |||
| .get_customizable_model_schema_from_credentials( | |||
| model_configuration.model, | |||
| model_configuration.credentials | |||
| ) | |||
| ) | |||
| if not custom_model_schema: | |||
| continue | |||
| provider_models.append( | |||
| ModelWithProviderEntity( | |||
| **custom_model_schema.dict(), | |||
| provider=SimpleModelProviderEntity(self.provider), | |||
| status=ModelStatus.ACTIVE | |||
| ) | |||
| ) | |||
| return provider_models | |||
| class ProviderConfigurations(BaseModel): | |||
| """ | |||
| Model class for provider configuration dict. | |||
| """ | |||
| tenant_id: str | |||
| configurations: Dict[str, ProviderConfiguration] = {} | |||
| def __init__(self, tenant_id: str): | |||
| super().__init__(tenant_id=tenant_id) | |||
| def get_models(self, | |||
| provider: Optional[str] = None, | |||
| model_type: Optional[ModelType] = None, | |||
| only_active: bool = False) \ | |||
| -> list[ModelWithProviderEntity]: | |||
| """ | |||
| Get available models. | |||
| If preferred provider type is `system`: | |||
| Use the current **system mode** if the provider supports it; | |||
| if no system mode is available (quota exhausted), fall back to **custom credential mode**. | |||
| If no model is configured in custom mode, treat it as no_configure. | |||
| system > custom > no_configure | |||
| If preferred provider type is `custom`: | |||
| If custom credentials are configured, treat it as custom mode. | |||
| Otherwise, use the current **system mode** if supported; | |||
| if no system mode is available (quota exhausted), treat it as no_configure. | |||
| custom > system > no_configure | |||
| If the real mode is `system`, use system credentials to get models, | |||
| paid quotas > provider free quotas > system free quotas, | |||
| including pre-defined models (GPT-4 is excluded; its status is marked `no_permission`). | |||
| If the real mode is `custom`, use workspace custom credentials to get models, | |||
| including pre-defined models and custom models (manually appended). | |||
| If the real mode is `no_configure`, only return pre-defined models from the `model runtime` | |||
| (model status is marked `no_configure` if the preferred provider type is `custom`, otherwise `quota_exceeded`). | |||
| Only models whose status is `active` are available. | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param only_active: only active models | |||
| :return: | |||
| """ | |||
| all_models = [] | |||
| for provider_configuration in self.values(): | |||
| if provider and provider_configuration.provider.provider != provider: | |||
| continue | |||
| all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) | |||
| return all_models | |||
| def to_list(self) -> List[ProviderConfiguration]: | |||
| """ | |||
| Convert to list. | |||
| :return: | |||
| """ | |||
| return list(self.values()) | |||
| def __getitem__(self, key): | |||
| return self.configurations[key] | |||
| def __setitem__(self, key, value): | |||
| self.configurations[key] = value | |||
| def __iter__(self): | |||
| return iter(self.configurations) | |||
| def values(self) -> Iterator[ProviderConfiguration]: | |||
| return self.configurations.values() | |||
| def get(self, key, default=None): | |||
| return self.configurations.get(key, default) | |||
| class ProviderModelBundle(BaseModel): | |||
| """ | |||
| Provider model bundle. | |||
| """ | |||
| configuration: ProviderConfiguration | |||
| provider_instance: ModelProvider | |||
| model_type_instance: AIModel | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
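For illustration, ProviderConfigurations exposes a dict-like interface (get, values, __getitem__), so a caller can look up one provider's configuration or aggregate models across providers. A minimal sketch, assuming a populated instance named configurations:

# Sketch only: `configurations` is assumed to be a ProviderConfigurations
# instance already filled by the provider manager.
openai_config = configurations.get("openai")
if openai_config:
    # only_active=True keeps models whose status is ACTIVE
    active_llms = openai_config.get_provider_models(
        model_type=ModelType.LLM,
        only_active=True
    )
    for model in active_llms:
        print(model.model, model.status)

# Or aggregate across all configured providers
all_active_models = configurations.get_models(only_active=True)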
| @@ -0,0 +1,67 @@ | |||
| from enum import Enum | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.provider import ProviderQuotaType | |||
| class QuotaUnit(Enum): | |||
| TIMES = 'times' | |||
| TOKENS = 'tokens' | |||
| class SystemConfigurationStatus(Enum): | |||
| """ | |||
| Enum class for system configuration status. | |||
| """ | |||
| ACTIVE = 'active' | |||
| QUOTA_EXCEEDED = 'quota-exceeded' | |||
| UNSUPPORTED = 'unsupported' | |||
| class QuotaConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider quota configuration. | |||
| """ | |||
| quota_type: ProviderQuotaType | |||
| quota_unit: QuotaUnit | |||
| quota_limit: int | |||
| quota_used: int | |||
| is_valid: bool | |||
| restrict_llms: list[str] = [] | |||
| class SystemConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider system configuration. | |||
| """ | |||
| enabled: bool | |||
| current_quota_type: Optional[ProviderQuotaType] = None | |||
| quota_configurations: list[QuotaConfiguration] = [] | |||
| credentials: Optional[dict] = None | |||
| class CustomProviderConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider custom configuration. | |||
| """ | |||
| credentials: dict | |||
| class CustomModelConfiguration(BaseModel): | |||
| """ | |||
| Model class for provider custom model configuration. | |||
| """ | |||
| model: str | |||
| model_type: ModelType | |||
| credentials: dict | |||
| class CustomConfiguration(BaseModel): | |||
| """ | |||
| Model class for the provider's full custom configuration (custom credentials and custom models). | |||
| """ | |||
| provider: Optional[CustomProviderConfiguration] = None | |||
| models: list[CustomModelConfiguration] = [] | |||
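To show how these entities nest, here is a small sketch with made-up values (the credentials and quota numbers are placeholders, not real configuration):

# Hypothetical values, shown only to illustrate how the entities compose.
trial_quota = QuotaConfiguration(
    quota_type=ProviderQuotaType.TRIAL,
    quota_unit=QuotaUnit.TIMES,
    quota_limit=200,
    quota_used=13,
    is_valid=True,
    restrict_llms=["gpt-3.5-turbo"]
)

system_configuration = SystemConfiguration(
    enabled=True,
    current_quota_type=ProviderQuotaType.TRIAL,
    quota_configurations=[trial_quota]
)

custom_configuration = CustomConfiguration(
    provider=CustomProviderConfiguration(credentials={"openai_api_key": "..."}),
    models=[]
)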
| @@ -0,0 +1,118 @@ | |||
| from enum import Enum | |||
| from typing import Any | |||
| from pydantic import BaseModel | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| class QueueEvent(Enum): | |||
| """ | |||
| QueueEvent enum | |||
| """ | |||
| MESSAGE = "message" | |||
| MESSAGE_REPLACE = "message-replace" | |||
| MESSAGE_END = "message-end" | |||
| RETRIEVER_RESOURCES = "retriever-resources" | |||
| ANNOTATION_REPLY = "annotation-reply" | |||
| AGENT_THOUGHT = "agent-thought" | |||
| ERROR = "error" | |||
| PING = "ping" | |||
| STOP = "stop" | |||
| class AppQueueEvent(BaseModel): | |||
| """ | |||
| QueueEvent entity | |||
| """ | |||
| event: QueueEvent | |||
| class QueueMessageEvent(AppQueueEvent): | |||
| """ | |||
| QueueMessageEvent entity | |||
| """ | |||
| event = QueueEvent.MESSAGE | |||
| chunk: LLMResultChunk | |||
| class QueueMessageReplaceEvent(AppQueueEvent): | |||
| """ | |||
| QueueMessageReplaceEvent entity | |||
| """ | |||
| event = QueueEvent.MESSAGE_REPLACE | |||
| text: str | |||
| class QueueRetrieverResourcesEvent(AppQueueEvent): | |||
| """ | |||
| QueueRetrieverResourcesEvent entity | |||
| """ | |||
| event = QueueEvent.RETRIEVER_RESOURCES | |||
| retriever_resources: list[dict] | |||
| class AnnotationReplyEvent(AppQueueEvent): | |||
| """ | |||
| AnnotationReplyEvent entity | |||
| """ | |||
| event = QueueEvent.ANNOTATION_REPLY | |||
| message_annotation_id: str | |||
| class QueueMessageEndEvent(AppQueueEvent): | |||
| """ | |||
| QueueMessageEndEvent entity | |||
| """ | |||
| event = QueueEvent.MESSAGE_END | |||
| llm_result: LLMResult | |||
| class QueueAgentThoughtEvent(AppQueueEvent): | |||
| """ | |||
| QueueAgentThoughtEvent entity | |||
| """ | |||
| event = QueueEvent.AGENT_THOUGHT | |||
| agent_thought_id: str | |||
| class QueueErrorEvent(AppQueueEvent): | |||
| """ | |||
| QueueErrorEvent entity | |||
| """ | |||
| event = QueueEvent.ERROR | |||
| error: Any | |||
| class QueuePingEvent(AppQueueEvent): | |||
| """ | |||
| QueuePingEvent entity | |||
| """ | |||
| event = QueueEvent.PING | |||
| class QueueStopEvent(AppQueueEvent): | |||
| """ | |||
| QueueStopEvent entity | |||
| """ | |||
| class StopBy(Enum): | |||
| """ | |||
| Stop by enum | |||
| """ | |||
| USER_MANUAL = "user-manual" | |||
| ANNOTATION_REPLY = "annotation-reply" | |||
| OUTPUT_MODERATION = "output-moderation" | |||
| event = QueueEvent.STOP | |||
| stopped_by: StopBy | |||
| class QueueMessage(BaseModel): | |||
| """ | |||
| QueueMessage entity | |||
| """ | |||
| task_id: str | |||
| message_id: str | |||
| conversation_id: str | |||
| app_mode: str | |||
| event: AppQueueEvent | |||
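A short sketch of how one of these events is wrapped into a QueueMessage before being published; the identifiers are placeholders and assume the pydantic defaults above fill in the event type:

stop_event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)

message = QueueMessage(
    task_id="task-id",
    message_id="message-id",
    conversation_id="conversation-id",
    app_mode="chat",
    event=stop_event
)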
| @@ -14,26 +14,6 @@ class LLMBadRequestError(LLMError): | |||
| description = "Bad Request" | |||
| class LLMAPIConnectionError(LLMError): | |||
| """Raised when the LLM returns API connection error.""" | |||
| description = "API Connection Error" | |||
| class LLMAPIUnavailableError(LLMError): | |||
| """Raised when the LLM returns API unavailable error.""" | |||
| description = "API Unavailable Error" | |||
| class LLMRateLimitError(LLMError): | |||
| """Raised when the LLM returns rate limit error.""" | |||
| description = "Rate Limit Error" | |||
| class LLMAuthorizationError(LLMError): | |||
| """Raised when the LLM returns authorization error.""" | |||
| description = "Authorization Error" | |||
| class ProviderTokenNotInitError(Exception): | |||
| """ | |||
| Custom exception raised when the provider token is not initialized. | |||
| @@ -0,0 +1,35 @@ | |||
| { | |||
| "label": { | |||
| "en-US": "Weather Search", | |||
| "zh-Hans": "天气查询" | |||
| }, | |||
| "form_schema": [ | |||
| { | |||
| "type": "select", | |||
| "label": { | |||
| "en-US": "Temperature Unit", | |||
| "zh-Hans": "温度单位" | |||
| }, | |||
| "variable": "temperature_unit", | |||
| "required": true, | |||
| "options": [ | |||
| { | |||
| "label": { | |||
| "en-US": "Fahrenheit", | |||
| "zh-Hans": "华氏度" | |||
| }, | |||
| "value": "fahrenheit" | |||
| }, | |||
| { | |||
| "label": { | |||
| "en-US": "Centigrade", | |||
| "zh-Hans": "摄氏度" | |||
| }, | |||
| "value": "centigrade" | |||
| } | |||
| ], | |||
| "default": "centigrade", | |||
| "placeholder": "Please select temperature unit" | |||
| } | |||
| ] | |||
| } | |||
| @@ -0,0 +1,45 @@ | |||
| from typing import Optional | |||
| from core.external_data_tool.base import ExternalDataTool | |||
| class WeatherSearch(ExternalDataTool): | |||
| """ | |||
| The name of the custom type must be unique and must match the directory and file name. | |||
| """ | |||
| name: str = "weather_search" | |||
| @classmethod | |||
| def validate_config(cls, tenant_id: str, config: dict) -> None: | |||
| """ | |||
| schema.json validation. It will be called when the user saves the config. | |||
| Example: | |||
| .. code-block:: python | |||
| config = { | |||
| "temperature_unit": "centigrade" | |||
| } | |||
| :param tenant_id: the id of workspace | |||
| :param config: the variables of form config | |||
| :return: | |||
| """ | |||
| if not config.get('temperature_unit'): | |||
| raise ValueError('temperature unit is required') | |||
| def query(self, inputs: dict, query: Optional[str] = None) -> str: | |||
| """ | |||
| Query the external data tool. | |||
| :param inputs: user inputs | |||
| :param query: the query of chat app | |||
| :return: the tool query result | |||
| """ | |||
| city = inputs.get('city') | |||
| temperature_unit = self.config.get('temperature_unit') | |||
| if temperature_unit == 'fahrenheit': | |||
| return f'Weather in {city} is 32°F' | |||
| else: | |||
| return f'Weather in {city} is 0°C' | |||
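A quick usage sketch for this sample tool. The constructor keyword arguments (tenant_id, app_id, variable, config) are assumed from how ExternalDataToolFactory is constructed elsewhere in this change, so adjust them to the real base class:

# Assumed constructor signature; values are placeholders.
WeatherSearch.validate_config(tenant_id="tenant-id", config={"temperature_unit": "fahrenheit"})

tool = WeatherSearch(
    tenant_id="tenant-id",
    app_id="app-id",
    variable="weather",
    config={"temperature_unit": "fahrenheit"}
)
print(tool.query(inputs={"city": "London"}))  # -> "Weather in London is 32°F"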
| @@ -0,0 +1,325 @@ | |||
| import logging | |||
| from typing import cast, Optional, List | |||
| from langchain import WikipediaAPIWrapper | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.tools import BaseTool, WikipediaQueryRun, Tool | |||
| from pydantic import BaseModel, Field | |||
| from core.agent.agent.agent_llm_callback import AgentLLMCallback | |||
| from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor | |||
| from core.application_queue_manager import ApplicationQueueManager | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \ | |||
| AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.tool.current_datetime_tool import DatetimeTool | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| from core.tool.provider.serpapi_provider import SerpAPIToolProvider | |||
| from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput | |||
| from core.tool.web_reader_tool import WebReaderTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.model import Message | |||
| logger = logging.getLogger(__name__) | |||
| class AgentRunnerFeature: | |||
| def __init__(self, tenant_id: str, | |||
| app_orchestration_config: AppOrchestrationConfigEntity, | |||
| model_config: ModelConfigEntity, | |||
| config: AgentEntity, | |||
| queue_manager: ApplicationQueueManager, | |||
| message: Message, | |||
| user_id: str, | |||
| agent_llm_callback: AgentLLMCallback, | |||
| callback: AgentLoopGatherCallbackHandler, | |||
| memory: Optional[TokenBufferMemory] = None,) -> None: | |||
| """ | |||
| Agent runner | |||
| :param tenant_id: tenant id | |||
| :param app_orchestration_config: app orchestration config | |||
| :param model_config: model config | |||
| :param config: agent config | |||
| :param queue_manager: queue manager | |||
| :param message: message | |||
| :param user_id: user id | |||
| :param agent_llm_callback: agent llm callback | |||
| :param callback: callback | |||
| :param memory: memory | |||
| """ | |||
| self.tenant_id = tenant_id | |||
| self.app_orchestration_config = app_orchestration_config | |||
| self.model_config = model_config | |||
| self.config = config | |||
| self.queue_manager = queue_manager | |||
| self.message = message | |||
| self.user_id = user_id | |||
| self.agent_llm_callback = agent_llm_callback | |||
| self.callback = callback | |||
| self.memory = memory | |||
| def run(self, query: str, | |||
| invoke_from: InvokeFrom) -> Optional[str]: | |||
| """ | |||
| Retrieve agent loop result. | |||
| :param query: query | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| provider = self.config.provider | |||
| model = self.config.model | |||
| tool_configs = self.config.tools | |||
| # check whether the model supports tool calling | |||
| provider_instance = model_provider_factory.get_provider_instance(provider=provider) | |||
| model_type_instance = provider_instance.get_model_instance(ModelType.LLM) | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # get model schema | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=model, | |||
| credentials=self.model_config.credentials | |||
| ) | |||
| if not model_schema: | |||
| return None | |||
| planning_strategy = PlanningStrategy.REACT | |||
| features = model_schema.features | |||
| if features: | |||
| if ModelFeature.TOOL_CALL in features \ | |||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||
| planning_strategy = PlanningStrategy.FUNCTION_CALL | |||
| tools = self.to_tools( | |||
| tool_configs=tool_configs, | |||
| invoke_from=invoke_from, | |||
| callbacks=[self.callback, DifyStdOutCallbackHandler()], | |||
| ) | |||
| if len(tools) == 0: | |||
| return None | |||
| agent_configuration = AgentConfiguration( | |||
| strategy=planning_strategy, | |||
| model_config=self.model_config, | |||
| tools=tools, | |||
| memory=self.memory, | |||
| max_iterations=10, | |||
| max_execution_time=400.0, | |||
| early_stopping_method="generate", | |||
| agent_llm_callback=self.agent_llm_callback, | |||
| callbacks=[self.callback, DifyStdOutCallbackHandler()] | |||
| ) | |||
| agent_executor = AgentExecutor(agent_configuration) | |||
| try: | |||
| # check if should use agent | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if not should_use_agent: | |||
| return None | |||
| result = agent_executor.run(query) | |||
| return result.output | |||
| except Exception as ex: | |||
| logger.exception("agent_executor run failed") | |||
| return None | |||
| def to_tools(self, tool_configs: list[AgentToolEntity], | |||
| invoke_from: InvokeFrom, | |||
| callbacks: list[BaseCallbackHandler]) \ | |||
| -> Optional[List[BaseTool]]: | |||
| """ | |||
| Convert tool configs to tools | |||
| :param tool_configs: tool configs | |||
| :param invoke_from: invoke from | |||
| :param callbacks: callbacks | |||
| """ | |||
| tools = [] | |||
| for tool_config in tool_configs: | |||
| tool = None | |||
| if tool_config.tool_id == "dataset": | |||
| tool = self.to_dataset_retriever_tool( | |||
| tool_config=tool_config.config, | |||
| invoke_from=invoke_from | |||
| ) | |||
| elif tool_config.tool_id == "web_reader": | |||
| tool = self.to_web_reader_tool( | |||
| tool_config=tool_config.config, | |||
| invoke_from=invoke_from | |||
| ) | |||
| elif tool_config.tool_id == "google_search": | |||
| tool = self.to_google_search_tool( | |||
| tool_config=tool_config.config, | |||
| invoke_from=invoke_from | |||
| ) | |||
| elif tool_config.tool_id == "wikipedia": | |||
| tool = self.to_wikipedia_tool( | |||
| tool_config=tool_config.config, | |||
| invoke_from=invoke_from | |||
| ) | |||
| elif tool_config.tool_id == "current_datetime": | |||
| tool = self.to_current_datetime_tool( | |||
| tool_config=tool_config.config, | |||
| invoke_from=invoke_from | |||
| ) | |||
| if tool: | |||
| if tool.callbacks is not None: | |||
| tool.callbacks.extend(callbacks) | |||
| else: | |||
| tool.callbacks = callbacks | |||
| tools.append(tool) | |||
| return tools | |||
| def to_dataset_retriever_tool(self, tool_config: dict, | |||
| invoke_from: InvokeFrom) \ | |||
| -> Optional[BaseTool]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param tool_config: tool config | |||
| :param invoke_from: invoke from | |||
| """ | |||
| show_retrieve_source = self.app_orchestration_config.show_retrieve_source | |||
| hit_callback = DatasetIndexToolCallbackHandler( | |||
| queue_manager=self.queue_manager, | |||
| app_id=self.message.app_id, | |||
| message_id=self.message.id, | |||
| user_id=self.user_id, | |||
| invoke_from=invoke_from | |||
| ) | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == self.tenant_id, | |||
| Dataset.id == tool_config.get("id") | |||
| ).first() | |||
| # pass if dataset is not available | |||
| if not dataset: | |||
| return None | |||
| # pass if dataset has no available documents | |||
| if dataset.available_document_count == 0: | |||
| return None | |||
| # get retrieval model config | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| retrieval_model_config = dataset.retrieval_model \ | |||
| if dataset.retrieval_model else default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config['top_k'] | |||
| # get score threshold | |||
| score_threshold = None | |||
| score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") | |||
| if score_threshold_enabled: | |||
| score_threshold = retrieval_model_config.get("score_threshold") | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=show_retrieve_source, | |||
| retriever_from=invoke_from.to_source() | |||
| ) | |||
| return tool | |||
| def to_web_reader_tool(self, tool_config: dict, | |||
| invoke_from: InvokeFrom) -> Optional[BaseTool]: | |||
| """ | |||
| A tool for reading web pages | |||
| :param tool_config: tool config | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| model_parameters = { | |||
| "temperature": 0, | |||
| "max_tokens": 500 | |||
| } | |||
| tool = WebReaderTool( | |||
| model_config=self.model_config, | |||
| model_parameters=model_parameters, | |||
| max_chunk_length=4000, | |||
| continue_reading=True | |||
| ) | |||
| return tool | |||
| def to_google_search_tool(self, tool_config: dict, | |||
| invoke_from: InvokeFrom) -> Optional[BaseTool]: | |||
| """ | |||
| A tool for performing a Google search and extracting snippets and webpages | |||
| :param tool_config: tool config | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id) | |||
| func_kwargs = tool_provider.credentials_to_func_kwargs() | |||
| if not func_kwargs: | |||
| return None | |||
| tool = Tool( | |||
| name="google_search", | |||
| description="A tool for performing a Google search and extracting snippets and webpages " | |||
| "when you need to search for something you don't know or when your information " | |||
| "is not up to date. " | |||
| "Input should be a search query.", | |||
| func=OptimizedSerpAPIWrapper(**func_kwargs).run, | |||
| args_schema=OptimizedSerpAPIInput | |||
| ) | |||
| return tool | |||
| def to_current_datetime_tool(self, tool_config: dict, | |||
| invoke_from: InvokeFrom) -> Optional[BaseTool]: | |||
| """ | |||
| A tool for getting the current date and time | |||
| :param tool_config: tool config | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| return DatetimeTool() | |||
| def to_wikipedia_tool(self, tool_config: dict, | |||
| invoke_from: InvokeFrom) -> Optional[BaseTool]: | |||
| """ | |||
| A tool for searching Wikipedia | |||
| :param tool_config: tool config | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| class WikipediaInput(BaseModel): | |||
| query: str = Field(..., description="search query.") | |||
| return WikipediaQueryRun( | |||
| name="wikipedia", | |||
| api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), | |||
| args_schema=WikipediaInput | |||
| ) | |||
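For reference, to_tools expects a list of AgentToolEntity items keyed by tool_id. A sketch of such a list, assuming AgentToolEntity is a simple model with tool_id and config fields (inferred from the usage above):

# Assumed fields (tool_id, config); the dataset id is a placeholder.
tool_configs = [
    AgentToolEntity(tool_id="current_datetime", config={}),
    AgentToolEntity(tool_id="dataset", config={"id": "dataset-uuid"}),
]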
| @@ -0,0 +1,119 @@ | |||
| import logging | |||
| from typing import Optional | |||
| from flask import current_app | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.entities.application_entities import InvokeFrom | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.model import App, Message, AppAnnotationSetting, MessageAnnotation | |||
| from services.annotation_service import AppAnnotationService | |||
| from services.dataset_service import DatasetCollectionBindingService | |||
| logger = logging.getLogger(__name__) | |||
| class AnnotationReplyFeature: | |||
| def query(self, app_record: App, | |||
| message: Message, | |||
| query: str, | |||
| user_id: str, | |||
| invoke_from: InvokeFrom) -> Optional[MessageAnnotation]: | |||
| """ | |||
| Query app annotations to reply | |||
| :param app_record: app record | |||
| :param message: message | |||
| :param query: query | |||
| :param user_id: user id | |||
| :param invoke_from: invoke from | |||
| :return: | |||
| """ | |||
| annotation_setting = db.session.query(AppAnnotationSetting).filter( | |||
| AppAnnotationSetting.app_id == app_record.id).first() | |||
| if not annotation_setting: | |||
| return None | |||
| collection_binding_detail = annotation_setting.collection_binding_detail | |||
| try: | |||
| score_threshold = annotation_setting.score_threshold or 1 | |||
| embedding_provider_name = collection_binding_detail.provider_name | |||
| embedding_model_name = collection_binding_detail.model_name | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_model_instance( | |||
| tenant_id=app_record.tenant_id, | |||
| provider=embedding_provider_name, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=embedding_model_name | |||
| ) | |||
| # get embedding model | |||
| embeddings = CacheEmbedding(model_instance) | |||
| dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( | |||
| embedding_provider_name, | |||
| embedding_model_name, | |||
| 'annotation' | |||
| ) | |||
| dataset = Dataset( | |||
| id=app_record.id, | |||
| tenant_id=app_record.tenant_id, | |||
| indexing_technique='high_quality', | |||
| embedding_model_provider=embedding_provider_name, | |||
| embedding_model=embedding_model_name, | |||
| collection_binding_id=dataset_collection_binding.id | |||
| ) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings, | |||
| attributes=['doc_id', 'annotation_id', 'app_id'] | |||
| ) | |||
| documents = vector_index.search( | |||
| query=query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': 1, | |||
| 'score_threshold': score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| } | |||
| ) | |||
| if documents: | |||
| annotation_id = documents[0].metadata['annotation_id'] | |||
| score = documents[0].metadata['score'] | |||
| annotation = AppAnnotationService.get_annotation_by_id(annotation_id) | |||
| if annotation: | |||
| if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: | |||
| from_source = 'api' | |||
| else: | |||
| from_source = 'console' | |||
| # insert annotation history | |||
| AppAnnotationService.add_annotation_history(annotation.id, | |||
| app_record.id, | |||
| annotation.question, | |||
| annotation.content, | |||
| query, | |||
| user_id, | |||
| message.id, | |||
| from_source, | |||
| score) | |||
| return annotation | |||
| except Exception as e: | |||
| logger.warning(f'Query annotation failed, exception: {str(e)}.') | |||
| return None | |||
| return None | |||
| @@ -0,0 +1,181 @@ | |||
| from typing import cast, Optional, List | |||
| from langchain.tools import BaseTool | |||
| from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor | |||
| from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler | |||
| from core.entities.application_entities import DatasetEntity, ModelConfigEntity, InvokeFrom, DatasetRetrieveConfigEntity | |||
| from core.memory.token_buffer_memory import TokenBufferMemory | |||
| from core.model_runtime.entities.model_entities import ModelFeature | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| class DatasetRetrievalFeature: | |||
| def retrieve(self, tenant_id: str, | |||
| model_config: ModelConfigEntity, | |||
| config: DatasetEntity, | |||
| query: str, | |||
| invoke_from: InvokeFrom, | |||
| show_retrieve_source: bool, | |||
| hit_callback: DatasetIndexToolCallbackHandler, | |||
| memory: Optional[TokenBufferMemory] = None) -> Optional[str]: | |||
| """ | |||
| Retrieve dataset. | |||
| :param tenant_id: tenant id | |||
| :param model_config: model config | |||
| :param config: dataset config | |||
| :param query: query | |||
| :param invoke_from: invoke from | |||
| :param show_retrieve_source: show retrieve source | |||
| :param hit_callback: hit callback | |||
| :param memory: memory | |||
| :return: | |||
| """ | |||
| dataset_ids = config.dataset_ids | |||
| retrieve_config = config.retrieve_config | |||
| # check whether the model supports tool calling | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| # get model schema | |||
| model_schema = model_type_instance.get_model_schema( | |||
| model=model_config.model, | |||
| credentials=model_config.credentials | |||
| ) | |||
| if not model_schema: | |||
| return None | |||
| planning_strategy = PlanningStrategy.REACT_ROUTER | |||
| features = model_schema.features | |||
| if features: | |||
| if ModelFeature.TOOL_CALL in features \ | |||
| or ModelFeature.MULTI_TOOL_CALL in features: | |||
| planning_strategy = PlanningStrategy.ROUTER | |||
| dataset_retriever_tools = self.to_dataset_retriever_tool( | |||
| tenant_id=tenant_id, | |||
| dataset_ids=dataset_ids, | |||
| retrieve_config=retrieve_config, | |||
| return_resource=show_retrieve_source, | |||
| invoke_from=invoke_from, | |||
| hit_callback=hit_callback | |||
| ) | |||
| if len(dataset_retriever_tools) == 0: | |||
| return None | |||
| agent_configuration = AgentConfiguration( | |||
| strategy=planning_strategy, | |||
| model_config=model_config, | |||
| tools=dataset_retriever_tools, | |||
| memory=memory, | |||
| max_iterations=10, | |||
| max_execution_time=400.0, | |||
| early_stopping_method="generate" | |||
| ) | |||
| agent_executor = AgentExecutor(agent_configuration) | |||
| should_use_agent = agent_executor.should_use_agent(query) | |||
| if not should_use_agent: | |||
| return None | |||
| result = agent_executor.run(query) | |||
| return result.output | |||
| def to_dataset_retriever_tool(self, tenant_id: str, | |||
| dataset_ids: list[str], | |||
| retrieve_config: DatasetRetrieveConfigEntity, | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler) \ | |||
| -> Optional[List[BaseTool]]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param tenant_id: tenant id | |||
| :param dataset_ids: dataset ids | |||
| :param retrieve_config: retrieve config | |||
| :param return_resource: return resource | |||
| :param invoke_from: invoke from | |||
| :param hit_callback: hit callback | |||
| """ | |||
| tools = [] | |||
| available_datasets = [] | |||
| for dataset_id in dataset_ids: | |||
| # get dataset from dataset id | |||
| dataset = db.session.query(Dataset).filter( | |||
| Dataset.tenant_id == tenant_id, | |||
| Dataset.id == dataset_id | |||
| ).first() | |||
| # skip if dataset is not available | |||
| if not dataset: | |||
| continue | |||
| # skip if dataset has no available documents | |||
| if dataset.available_document_count == 0: | |||
| continue | |||
| available_datasets.append(dataset) | |||
| if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: | |||
| # get retrieval model config | |||
| default_retrieval_model = { | |||
| 'search_method': 'semantic_search', | |||
| 'reranking_enable': False, | |||
| 'reranking_model': { | |||
| 'reranking_provider_name': '', | |||
| 'reranking_model_name': '' | |||
| }, | |||
| 'top_k': 2, | |||
| 'score_threshold_enabled': False | |||
| } | |||
| for dataset in available_datasets: | |||
| retrieval_model_config = dataset.retrieval_model \ | |||
| if dataset.retrieval_model else default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config['top_k'] | |||
| # get score threshold | |||
| score_threshold = None | |||
| score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") | |||
| if score_threshold_enabled: | |||
| score_threshold = retrieval_model_config.get("score_threshold") | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=return_resource, | |||
| retriever_from=invoke_from.to_source() | |||
| ) | |||
| tools.append(tool) | |||
| elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: | |||
| tool = DatasetMultiRetrieverTool.from_dataset( | |||
| dataset_ids=[dataset.id for dataset in available_datasets], | |||
| tenant_id=tenant_id, | |||
| top_k=retrieve_config.top_k or 2, | |||
| score_threshold=(retrieve_config.score_threshold or 0.5) | |||
| if retrieve_config.reranking_model.get('score_threshold_enabled', False) else None, | |||
| hit_callbacks=[hit_callback], | |||
| return_resource=return_resource, | |||
| retriever_from=invoke_from.to_source(), | |||
| reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'), | |||
| reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name') | |||
| ) | |||
| tools.append(tool) | |||
| return tools | |||
| @@ -0,0 +1,96 @@ | |||
| import concurrent | |||
| import json | |||
| import logging | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Tuple, Optional | |||
| from flask import current_app, Flask | |||
| from core.entities.application_entities import ExternalDataVariableEntity | |||
| from core.external_data_tool.factory import ExternalDataToolFactory | |||
| logger = logging.getLogger(__name__) | |||
| class ExternalDataFetchFeature: | |||
| def fetch(self, tenant_id: str, | |||
| app_id: str, | |||
| external_data_tools: list[ExternalDataVariableEntity], | |||
| inputs: dict, | |||
| query: str) -> dict: | |||
| """ | |||
| Fill in variable inputs from external data tools if they exist. | |||
| :param tenant_id: workspace id | |||
| :param app_id: app id | |||
| :param external_data_tools: external data tools configs | |||
| :param inputs: the inputs | |||
| :param query: the query | |||
| :return: the filled inputs | |||
| """ | |||
| # Group tools by type and config | |||
| grouped_tools = {} | |||
| for tool in external_data_tools: | |||
| tool_key = (tool.type, json.dumps(tool.config, sort_keys=True)) | |||
| grouped_tools.setdefault(tool_key, []).append(tool) | |||
| results = {} | |||
| with ThreadPoolExecutor() as executor: | |||
| futures = {} | |||
| for tool in external_data_tools: | |||
| future = executor.submit( | |||
| self._query_external_data_tool, | |||
| current_app._get_current_object(), | |||
| tenant_id, | |||
| app_id, | |||
| tool, | |||
| inputs, | |||
| query | |||
| ) | |||
| futures[future] = tool | |||
| for future in concurrent.futures.as_completed(futures): | |||
| tool_variable, result = future.result() | |||
| results[tool_variable] = result | |||
| inputs.update(results) | |||
| return inputs | |||
| def _query_external_data_tool(self, flask_app: Flask, | |||
| tenant_id: str, | |||
| app_id: str, | |||
| external_data_tool: ExternalDataVariableEntity, | |||
| inputs: dict, | |||
| query: str) -> Tuple[Optional[str], Optional[str]]: | |||
| """ | |||
| Query external data tool. | |||
| :param flask_app: flask app | |||
| :param tenant_id: tenant id | |||
| :param app_id: app id | |||
| :param external_data_tool: external data tool | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| with flask_app.app_context(): | |||
| tool_variable = external_data_tool.variable | |||
| tool_type = external_data_tool.type | |||
| tool_config = external_data_tool.config | |||
| external_data_tool_factory = ExternalDataToolFactory( | |||
| name=tool_type, | |||
| tenant_id=tenant_id, | |||
| app_id=app_id, | |||
| variable=tool_variable, | |||
| config=tool_config | |||
| ) | |||
| # query external data tool | |||
| result = external_data_tool_factory.query( | |||
| inputs=inputs, | |||
| query=query | |||
| ) | |||
| return tool_variable, result | |||
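A sketch of calling the feature with one external data variable; the ExternalDataVariableEntity fields (variable, type, config) are inferred from _query_external_data_tool above, and all values are placeholders:

external_data_tools = [
    ExternalDataVariableEntity(variable="weather", type="weather_search",
                               config={"temperature_unit": "centigrade"}),
]

inputs = ExternalDataFetchFeature().fetch(
    tenant_id="tenant-id",
    app_id="app-id",
    external_data_tools=external_data_tools,
    inputs={"city": "London"},
    query="What should I wear today?"
)
# inputs now contains {"city": "London", "weather": "..."}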
| @@ -0,0 +1,32 @@ | |||
| import logging | |||
| from core.entities.application_entities import ApplicationGenerateEntity | |||
| from core.helper import moderation | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| logger = logging.getLogger(__name__) | |||
| class HostingModerationFeature: | |||
| def check(self, application_generate_entity: ApplicationGenerateEntity, | |||
| prompt_messages: list[PromptMessage]) -> bool: | |||
| """ | |||
| Check hosting moderation | |||
| :param application_generate_entity: application generate entity | |||
| :param prompt_messages: prompt messages | |||
| :return: | |||
| """ | |||
| app_orchestration_config = application_generate_entity.app_orchestration_config_entity | |||
| model_config = app_orchestration_config.model_config | |||
| text = "" | |||
| for prompt_message in prompt_messages: | |||
| if isinstance(prompt_message.content, str): | |||
| text += prompt_message.content + "\n" | |||
| moderation_result = moderation.check_moderation( | |||
| model_config, | |||
| text | |||
| ) | |||
| return moderation_result | |||
| @@ -0,0 +1,50 @@ | |||
| import logging | |||
| from typing import Tuple | |||
| from core.entities.application_entities import AppOrchestrationConfigEntity | |||
| from core.moderation.base import ModerationAction, ModerationException | |||
| from core.moderation.factory import ModerationFactory | |||
| logger = logging.getLogger(__name__) | |||
| class ModerationFeature: | |||
| def check(self, app_id: str, | |||
| tenant_id: str, | |||
| app_orchestration_config_entity: AppOrchestrationConfigEntity, | |||
| inputs: dict, | |||
| query: str) -> Tuple[bool, dict, str]: | |||
| """ | |||
| Process sensitive_word_avoidance. | |||
| :param app_id: app id | |||
| :param tenant_id: tenant id | |||
| :param app_orchestration_config_entity: app orchestration config entity | |||
| :param inputs: inputs | |||
| :param query: query | |||
| :return: | |||
| """ | |||
| if not app_orchestration_config_entity.sensitive_word_avoidance: | |||
| return False, inputs, query | |||
| sensitive_word_avoidance_config = app_orchestration_config_entity.sensitive_word_avoidance | |||
| moderation_type = sensitive_word_avoidance_config.type | |||
| moderation_factory = ModerationFactory( | |||
| name=moderation_type, | |||
| app_id=app_id, | |||
| tenant_id=tenant_id, | |||
| config=sensitive_word_avoidance_config.config | |||
| ) | |||
| moderation_result = moderation_factory.moderation_for_inputs(inputs, query) | |||
| if not moderation_result.flagged: | |||
| return False, inputs, query | |||
| if moderation_result.action == ModerationAction.DIRECT_OUTPUT: | |||
| raise ModerationException(moderation_result.preset_response) | |||
| elif moderation_result.action == ModerationAction.OVERRIDED: | |||
| inputs = moderation_result.inputs | |||
| query = moderation_result.query | |||
| return True, inputs, query | |||
| @@ -4,7 +4,7 @@ from typing import Optional | |||
| from pydantic import BaseModel | |||
| from core.file.upload_file_parser import UploadFileParser | |||
| from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile | |||
| from core.model_runtime.entities.message_entities import ImagePromptMessageContent | |||
| from extensions.ext_database import db | |||
| from models.model import UploadFile | |||
| @@ -50,14 +50,14 @@ class FileObj(BaseModel): | |||
| return self._get_data(force_url=True) | |||
| @property | |||
| def prompt_message_file(self) -> PromptMessageFile: | |||
| def prompt_message_content(self) -> ImagePromptMessageContent: | |||
| if self.type == FileType.IMAGE: | |||
| image_config = self.file_config.get('image') | |||
| return ImagePromptMessageFile( | |||
| return ImagePromptMessageContent( | |||
| data=self.data, | |||
| detail=ImagePromptMessageFile.DETAIL.HIGH | |||
| if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW | |||
| detail=ImagePromptMessageContent.DETAIL.HIGH | |||
| if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW | |||
| ) | |||
| def _get_data(self, force_url: bool = False) -> Optional[str]: | |||
| @@ -3,10 +3,10 @@ import logging | |||
| from langchain.schema import OutputParserException | |||
| from core.model_providers.error import LLMError, ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelKwargs | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.message_entities import UserPromptMessage, SystemPromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError | |||
| from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser | |||
| from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser | |||
| @@ -26,17 +26,22 @@ class LLMGenerator: | |||
| prompt += query + "\n" | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| temperature=1, | |||
| max_tokens=100 | |||
| ) | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| prompts = [PromptMessage(content=prompt)] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| prompts = [UserPromptMessage(content=prompt)] | |||
| response = model_instance.invoke_llm( | |||
| prompt_messages=prompts, | |||
| model_parameters={ | |||
| "max_tokens": 100, | |||
| "temperature": 1 | |||
| }, | |||
| stream=False | |||
| ) | |||
| answer = response.message.content | |||
| result_dict = json.loads(answer) | |||
| answer = result_dict['Your Output'] | |||
| @@ -62,22 +67,28 @@ class LLMGenerator: | |||
| }) | |||
| try: | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=256, | |||
| temperature=0 | |||
| ) | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| except ProviderTokenNotInitError: | |||
| except InvokeAuthorizationError: | |||
| return [] | |||
| prompt_messages = [PromptMessage(content=prompt)] | |||
| prompt_messages = [UserPromptMessage(content=prompt)] | |||
| try: | |||
| output = model_instance.run(prompt_messages) | |||
| questions = output_parser.parse(output.content) | |||
| except LLMError: | |||
| response = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters={ | |||
| "max_tokens": 256, | |||
| "temperature": 0 | |||
| }, | |||
| stream=False | |||
| ) | |||
| questions = output_parser.parse(response.message.content) | |||
| except InvokeError: | |||
| questions = [] | |||
| except Exception as e: | |||
| logging.exception(e) | |||
| @@ -105,20 +116,26 @@ class LLMGenerator: | |||
| remove_template_variables=False | |||
| ) | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=512, | |||
| temperature=0 | |||
| ) | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| prompt_messages = [PromptMessage(content=prompt)] | |||
| prompt_messages = [UserPromptMessage(content=prompt)] | |||
| try: | |||
| output = model_instance.run(prompt_messages) | |||
| rule_config = output_parser.parse(output.content) | |||
| except LLMError as e: | |||
| response = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters={ | |||
| "max_tokens": 512, | |||
| "temperature": 0 | |||
| }, | |||
| stream=False | |||
| ) | |||
| rule_config = output_parser.parse(response.message.content) | |||
| except InvokeError as e: | |||
| raise e | |||
| except OutputParserException: | |||
| raise ValueError('Please give a valid input for intended audience or hoping to solve problems.') | |||
| @@ -136,18 +153,24 @@ class LLMGenerator: | |||
| def generate_qa_document(cls, tenant_id: str, query, document_language: str): | |||
| prompt = GENERATOR_QA_PROMPT.format(language=document_language) | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| model_manager = ModelManager() | |||
| model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=2000 | |||
| ) | |||
| model_type=ModelType.LLM, | |||
| ) | |||
| prompts = [ | |||
| PromptMessage(content=prompt, type=MessageType.SYSTEM), | |||
| PromptMessage(content=query) | |||
| prompt_messages = [ | |||
| SystemPromptMessage(content=prompt), | |||
| UserPromptMessage(content=query) | |||
| ] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| response = model_instance.invoke_llm( | |||
| prompt_messages=prompt_messages, | |||
| model_parameters={ | |||
| "max_tokens": 2000 | |||
| }, | |||
| stream=False | |||
| ) | |||
| answer = response.message.content | |||
| return answer.strip() | |||
| @@ -18,3 +18,17 @@ def encrypt_token(tenant_id: str, token: str): | |||
| def decrypt_token(tenant_id: str, token: str): | |||
| return rsa.decrypt(base64.b64decode(token), tenant_id) | |||
| def batch_decrypt_token(tenant_id: str, tokens: list[str]): | |||
| rsa_key, cipher_rsa = rsa.get_decrypt_decoding(tenant_id) | |||
| return [rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) for token in tokens] | |||
| def get_decrypt_decoding(tenant_id: str): | |||
| return rsa.get_decrypt_decoding(tenant_id) | |||
| def decrypt_token_with_decoding(token: str, rsa_key, cipher_rsa): | |||
| return rsa.decrypt_token_with_decoding(base64.b64decode(token), rsa_key, cipher_rsa) | |||
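batch_decrypt_token avoids re-deriving the tenant RSA decoding context for every token. A minimal sketch of the two equivalent call styles, where tenant_id and encrypted_tokens are placeholders:

# Decrypt several stored tokens for one tenant in a single pass.
plain_tokens = batch_decrypt_token(tenant_id, encrypted_tokens)

# Or reuse the decoding context explicitly:
rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
plain_tokens = [
    decrypt_token_with_decoding(token, rsa_key, cipher_rsa)
    for token in encrypted_tokens
]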
| @@ -0,0 +1,22 @@ | |||
| from collections import OrderedDict | |||
| from typing import Any | |||
| class LRUCache: | |||
| def __init__(self, capacity: int): | |||
| self.cache = OrderedDict() | |||
| self.capacity = capacity | |||
| def get(self, key: Any) -> Any: | |||
| if key not in self.cache: | |||
| return None | |||
| else: | |||
| self.cache.move_to_end(key) # move the key to the end of the OrderedDict | |||
| return self.cache[key] | |||
| def put(self, key: Any, value: Any) -> None: | |||
| if key in self.cache: | |||
| self.cache.move_to_end(key) | |||
| self.cache[key] = value | |||
| if len(self.cache) > self.capacity: | |||
| self.cache.popitem(last=False) # pop the first item | |||
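A short usage example showing the eviction behaviour (the least recently used entry is dropped once the capacity is exceeded):

cache = LRUCache(capacity=2)
cache.put("a", 1)
cache.put("b", 2)
cache.get("a")        # touching "a" makes "b" the least recently used
cache.put("c", 3)     # exceeds capacity, evicts "b"
assert cache.get("b") is None
assert cache.get("a") == 1 and cache.get("c") == 3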
| @@ -1,18 +1,27 @@ | |||
| import logging | |||
| import random | |||
| import openai | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.model_providers.providers.hosted import hosted_config, hosted_model_providers | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_runtime.errors.invoke import InvokeBadRequestError | |||
| from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel | |||
| from extensions.ext_hosting_provider import hosting_configuration | |||
| from models.provider import ProviderType | |||
| logger = logging.getLogger(__name__) | |||
| def check_moderation(model_config: ModelConfigEntity, text: str) -> bool: | |||
| moderation_config = hosting_configuration.moderation_config | |||
| if (moderation_config and moderation_config.enabled is True | |||
| and 'openai' in hosting_configuration.provider_map | |||
| and hosting_configuration.provider_map['openai'].enabled is True | |||
| ): | |||
| using_provider_type = model_config.provider_model_bundle.configuration.using_provider_type | |||
| provider_name = model_config.provider | |||
| if using_provider_type == ProviderType.SYSTEM \ | |||
| and provider_name in moderation_config.providers: | |||
| hosting_openai_config = hosting_configuration.provider_map['openai'] | |||
| def check_moderation(model_provider: BaseModelProvider, text: str) -> bool: | |||
| if hosted_config.moderation.enabled is True and hosted_model_providers.openai: | |||
| if model_provider.provider.provider_type == ProviderType.SYSTEM.value \ | |||
| and model_provider.provider_name in hosted_config.moderation.providers: | |||
| # 2,000 characters per chunk | |||
| length = 2000 | |||
| text_chunks = [text[i:i + length] for i in range(0, len(text), length)] | |||
| @@ -23,14 +32,17 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool: | |||
| text_chunk = random.choice(text_chunks) | |||
| try: | |||
| moderation_result = openai.Moderation.create(input=text_chunk, | |||
| api_key=hosted_model_providers.openai.api_key) | |||
| model_type_instance = OpenAIModerationModel() | |||
| moderation_result = model_type_instance.invoke( | |||
| model='text-moderation-stable', | |||
| credentials=hosting_openai_config.credentials, | |||
| text=text_chunk | |||
| ) | |||
| if moderation_result is True: | |||
| return True | |||
| except Exception as ex: | |||
| logging.exception(ex) | |||
| raise LLMBadRequestError('Rate limit exceeded, please try again later.') | |||
| for result in moderation_result.results: | |||
| if result['flagged'] is True: | |||
| return False | |||
| logger.exception(ex) | |||
| raise InvokeBadRequestError('Rate limit exceeded, please try again later.') | |||
| return True | |||
| return False | |||
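The chunking expression above splits the prompt text into fixed-size windows and moderates one randomly chosen window to bound cost. A small illustration of the slicing, using a shorter length than the real 2,000 for readability:

# Same slicing pattern as above, with length=5 instead of 2000 for clarity.
text = "abcdefghijk"
length = 5
chunks = [text[i:i + length] for i in range(0, len(text), length)]
# chunks == ["abcde", "fghij", "k"]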
| @@ -0,0 +1,213 @@ | |||
| import os | |||
| from typing import Optional | |||
| from flask import Flask | |||
| from pydantic import BaseModel | |||
| from core.entities.provider_entities import QuotaUnit | |||
| from models.provider import ProviderQuotaType | |||
| class HostingQuota(BaseModel): | |||
| quota_type: ProviderQuotaType | |||
| restrict_llms: list[str] = [] | |||
| class TrialHostingQuota(HostingQuota): | |||
| quota_type: ProviderQuotaType = ProviderQuotaType.TRIAL | |||
| quota_limit: int = 0 | |||
| """Quota limit for the hosting provider models. -1 means unlimited.""" | |||
| class PaidHostingQuota(HostingQuota): | |||
| quota_type: ProviderQuotaType = ProviderQuotaType.PAID | |||
| stripe_price_id: str = None | |||
| increase_quota: int = 1 | |||
| min_quantity: int = 20 | |||
| max_quantity: int = 100 | |||
| class FreeHostingQuota(HostingQuota): | |||
| quota_type: ProviderQuotaType = ProviderQuotaType.FREE | |||
| class HostingProvider(BaseModel): | |||
| enabled: bool = False | |||
| credentials: Optional[dict] = None | |||
| quota_unit: Optional[QuotaUnit] = None | |||
| quotas: list[HostingQuota] = [] | |||
| class HostedModerationConfig(BaseModel): | |||
| enabled: bool = False | |||
| providers: list[str] = [] | |||
| class HostingConfiguration: | |||
| provider_map: dict[str, HostingProvider] = {} | |||
| moderation_config: Optional[HostedModerationConfig] = None | |||
| def init_app(self, app: Flask): | |||
| if app.config.get('EDITION') != 'CLOUD': | |||
| return | |||
| self.provider_map["openai"] = self.init_openai() | |||
| self.provider_map["anthropic"] = self.init_anthropic() | |||
| self.provider_map["minimax"] = self.init_minimax() | |||
| self.provider_map["spark"] = self.init_spark() | |||
| self.provider_map["zhipuai"] = self.init_zhipuai() | |||
| self.moderation_config = self.init_moderation_config() | |||
| def init_openai(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TIMES | |||
| if os.environ.get("HOSTED_OPENAI_ENABLED") and os.environ.get("HOSTED_OPENAI_ENABLED").lower() == 'true': | |||
| credentials = { | |||
| "openai_api_key": os.environ.get("HOSTED_OPENAI_API_KEY"), | |||
| } | |||
| if os.environ.get("HOSTED_OPENAI_API_BASE"): | |||
| credentials["openai_api_base"] = os.environ.get("HOSTED_OPENAI_API_BASE") | |||
| if os.environ.get("HOSTED_OPENAI_API_ORGANIZATION"): | |||
| credentials["openai_organization"] = os.environ.get("HOSTED_OPENAI_API_ORGANIZATION") | |||
| quotas = [] | |||
| hosted_quota_limit = int(os.environ.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) | |||
| if hosted_quota_limit == -1 or hosted_quota_limit > 0: | |||
| trial_quota = TrialHostingQuota( | |||
| quota_limit=hosted_quota_limit, | |||
| restrict_llms=[ | |||
| "gpt-3.5-turbo", | |||
| "gpt-3.5-turbo-1106", | |||
| "gpt-3.5-turbo-instruct", | |||
| "gpt-3.5-turbo-16k", | |||
| "text-davinci-003" | |||
| ] | |||
| ) | |||
| quotas.append(trial_quota) | |||
| if os.environ.get("HOSTED_OPENAI_PAID_ENABLED") and os.environ.get( | |||
| "HOSTED_OPENAI_PAID_ENABLED").lower() == 'true': | |||
| paid_quota = PaidHostingQuota( | |||
| stripe_price_id=os.environ.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"), | |||
| increase_quota=int(os.environ.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA", "1")), | |||
| min_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MIN_QUANTITY", "1")), | |||
| max_quantity=int(os.environ.get("HOSTED_OPENAI_PAID_MAX_QUANTITY", "1")) | |||
| ) | |||
| quotas.append(paid_quota) | |||
| return HostingProvider( | |||
| enabled=True, | |||
| credentials=credentials, | |||
| quota_unit=quota_unit, | |||
| quotas=quotas | |||
| ) | |||
| return HostingProvider( | |||
| enabled=False, | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_anthropic(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if os.environ.get("HOSTED_ANTHROPIC_ENABLED") and os.environ.get("HOSTED_ANTHROPIC_ENABLED").lower() == 'true': | |||
| credentials = { | |||
| "anthropic_api_key": os.environ.get("HOSTED_ANTHROPIC_API_KEY"), | |||
| } | |||
| if os.environ.get("HOSTED_ANTHROPIC_API_BASE"): | |||
| credentials["anthropic_api_url"] = os.environ.get("HOSTED_ANTHROPIC_API_BASE") | |||
| quotas = [] | |||
| hosted_quota_limit = int(os.environ.get("HOSTED_ANTHROPIC_QUOTA_LIMIT", "0")) | |||
| if hosted_quota_limit == -1 or hosted_quota_limit > 0: | |||
| trial_quota = TrialHostingQuota( | |||
| quota_limit=hosted_quota_limit | |||
| ) | |||
| quotas.append(trial_quota) | |||
| if os.environ.get("HOSTED_ANTHROPIC_PAID_ENABLED") and os.environ.get( | |||
| "HOSTED_ANTHROPIC_PAID_ENABLED").lower() == 'true': | |||
| paid_quota = PaidHostingQuota( | |||
| stripe_price_id=os.environ.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"), | |||
| increase_quota=int(os.environ.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA", "1000000")), | |||
| min_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MIN_QUANTITY", "20")), | |||
| max_quantity=int(os.environ.get("HOSTED_ANTHROPIC_PAID_MAX_QUANTITY", "100")) | |||
| ) | |||
| quotas.append(paid_quota) | |||
| return HostingProvider( | |||
| enabled=True, | |||
| credentials=credentials, | |||
| quota_unit=quota_unit, | |||
| quotas=quotas | |||
| ) | |||
| return HostingProvider( | |||
| enabled=False, | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_minimax(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if os.environ.get("HOSTED_MINIMAX_ENABLED") and os.environ.get("HOSTED_MINIMAX_ENABLED").lower() == 'true': | |||
| quotas = [FreeHostingQuota()] | |||
| return HostingProvider( | |||
| enabled=True, | |||
| credentials=None, # use credentials from the provider | |||
| quota_unit=quota_unit, | |||
| quotas=quotas | |||
| ) | |||
| return HostingProvider( | |||
| enabled=False, | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_spark(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if os.environ.get("HOSTED_SPARK_ENABLED") and os.environ.get("HOSTED_SPARK_ENABLED").lower() == 'true': | |||
| quotas = [FreeHostingQuota()] | |||
| return HostingProvider( | |||
| enabled=True, | |||
| credentials=None, # use credentials from the provider | |||
| quota_unit=quota_unit, | |||
| quotas=quotas | |||
| ) | |||
| return HostingProvider( | |||
| enabled=False, | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_zhipuai(self) -> HostingProvider: | |||
| quota_unit = QuotaUnit.TOKENS | |||
| if os.environ.get("HOSTED_ZHIPUAI_ENABLED") and os.environ.get("HOSTED_ZHIPUAI_ENABLED").lower() == 'true': | |||
| quotas = [FreeHostingQuota()] | |||
| return HostingProvider( | |||
| enabled=True, | |||
| credentials=None, # use credentials from the provider | |||
| quota_unit=quota_unit, | |||
| quotas=quotas | |||
| ) | |||
| return HostingProvider( | |||
| enabled=False, | |||
| quota_unit=quota_unit, | |||
| ) | |||
| def init_moderation_config(self) -> HostedModerationConfig: | |||
| if os.environ.get("HOSTED_MODERATION_ENABLED") and os.environ.get("HOSTED_MODERATION_ENABLED").lower() == 'true' \ | |||
| and os.environ.get("HOSTED_MODERATION_PROVIDERS"): | |||
| return HostedModerationConfig( | |||
| enabled=True, | |||
| providers=os.environ.get("HOSTED_MODERATION_PROVIDERS").split(',') | |||
| ) | |||
| return HostedModerationConfig( | |||
| enabled=False | |||
| ) | |||
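Every `init_*` method above keys off environment variables, so enabling a hosted provider in the CLOUD edition is purely a deployment concern. A minimal sketch of the wiring, using the `hosting_configuration` instance imported from `extensions.ext_hosting_provider` earlier in this diff (the environment values are illustrative only):

```python
import os

from flask import Flask

from extensions.ext_hosting_provider import hosting_configuration

# illustrative values; real deployments set these in the environment
os.environ.setdefault("HOSTED_OPENAI_ENABLED", "true")
os.environ.setdefault("HOSTED_OPENAI_API_KEY", "sk-...")
os.environ.setdefault("HOSTED_MODERATION_ENABLED", "true")
os.environ.setdefault("HOSTED_MODERATION_PROVIDERS", "openai")

app = Flask(__name__)
app.config["EDITION"] = "CLOUD"  # init_app is a no-op for any other edition

hosting_configuration.init_app(app)
assert hosting_configuration.provider_map["openai"].enabled
assert hosting_configuration.moderation_config.enabled
```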
| @@ -1,18 +1,12 @@ | |||
| import json | |||
| from flask import current_app | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargs | |||
| from core.model_providers.models.llm.openai_model import OpenAIModel | |||
| from core.model_providers.providers.openai_provider import OpenAIProvider | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from models.dataset import Dataset | |||
| from models.provider import Provider, ProviderType | |||
| class IndexBuilder: | |||
| @@ -22,10 +16,12 @@ class IndexBuilder: | |||
| if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | |||
| return None | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| model_manager = ModelManager() | |||
| embedding_model = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| provider=dataset.embedding_model_provider, | |||
| model=dataset.embedding_model | |||
| ) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
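The same substitution recurs throughout the indexing code below: a `ModelManager` lookup replaces the old `ModelFactory` call, and the resulting instance is handed to `CacheEmbedding`. A condensed sketch of that pattern under the new API (the helper name is hypothetical; the dataset fields are the ones used above):

```python
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType


def build_embeddings(dataset):
    # resolve the tenant's configured embedding model through the new manager
    embedding_model = ModelManager().get_model_instance(
        tenant_id=dataset.tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model
    )
    # CacheEmbedding wraps the instance (caching behavior implied by its name)
    return CacheEmbedding(embedding_model)
```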
| @@ -18,9 +18,11 @@ from core.data_loader.loader.notion import NotionLoader | |||
| from core.docstore.dataset_docstore import DatasetDocumentStore | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.index.index import IndexBuilder | |||
| from core.model_providers.error import ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import MessageType | |||
| from core.model_manager import ModelManager | |||
| from core.errors.error import ProviderTokenNotInitError | |||
| from core.model_runtime.entities.model_entities import ModelType, PriceType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| @@ -36,6 +38,7 @@ class IndexingRunner: | |||
| def __init__(self): | |||
| self.storage = storage | |||
| self.model_manager = ModelManager() | |||
| def run(self, dataset_documents: List[DatasetDocument]): | |||
| """Run the indexing process.""" | |||
| @@ -210,7 +213,7 @@ class IndexingRunner: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| embedding_model = None | |||
| embedding_model_instance = None | |||
| if dataset_id: | |||
| dataset = Dataset.query.filter_by( | |||
| id=dataset_id | |||
| @@ -218,15 +221,17 @@ class IndexingRunner: | |||
| if not dataset: | |||
| raise ValueError('Dataset not found.') | |||
| if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| else: | |||
| if indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=tenant_id | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| tokens = 0 | |||
| preview_texts = [] | |||
| @@ -255,32 +260,56 @@ class IndexingRunner: | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| if indexing_technique == 'high_quality' or embedding_model: | |||
| tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content)) | |||
| if indexing_technique == 'high_quality' or embedding_model_instance: | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| tokens += embedding_model_type_instance.get_num_tokens( | |||
| model=embedding_model_instance.model, | |||
| credentials=embedding_model_instance.credentials, | |||
| texts=[self.filter_string(document.page_content)] | |||
| ) | |||
| if doc_form and doc_form == 'qa_model': | |||
| text_generation_model = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id | |||
| model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM | |||
| ) | |||
| model_type_instance = model_instance.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| if len(preview_texts) > 0: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], | |||
| doc_language) | |||
| document_qa_list = self.format_split_text(response) | |||
| price_info = model_type_instance.get_price( | |||
| model=model_instance.model, | |||
| credentials=model_instance.credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=total_segments * 2000, | |||
| ) | |||
| return { | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), | |||
| "currency": embedding_model.get_currency(), | |||
| "total_price": '{:f}'.format(price_info.total_amount), | |||
| "currency": price_info.currency, | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| } | |||
| if embedding_model_instance: | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance) | |||
| embedding_price_info = embedding_model_type_instance.get_price( | |||
| model=embedding_model_instance.model, | |||
| credentials=embedding_model_instance.credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=tokens | |||
| ) | |||
| return { | |||
| "total_segments": total_segments, | |||
| "tokens": tokens, | |||
| "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0, | |||
| "currency": embedding_model.get_currency() if embedding_model else 'USD', | |||
| "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0, | |||
| "currency": embedding_price_info.currency if embedding_model_instance else 'USD', | |||
| "preview": preview_texts | |||
| } | |||
| @@ -290,7 +319,7 @@ class IndexingRunner: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| embedding_model = None | |||
| embedding_model_instance = None | |||
| if dataset_id: | |||
| dataset = Dataset.query.filter_by( | |||
| id=dataset_id | |||
| @@ -298,15 +327,17 @@ class IndexingRunner: | |||
| if not dataset: | |||
| raise ValueError('Dataset not found.') | |||
| if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| else: | |||
| if indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=tenant_id | |||
| embedding_model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING | |||
| ) | |||
| # load data from notion | |||
| tokens = 0 | |||
| @@ -349,35 +380,63 @@ class IndexingRunner: | |||
| processing_rule=processing_rule | |||
| ) | |||
| total_segments += len(documents) | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| if indexing_technique == 'high_quality' or embedding_model: | |||
| tokens += embedding_model.get_num_tokens(document.page_content) | |||
| if indexing_technique == 'high_quality' or embedding_model_instance: | |||
| tokens += embedding_model_type_instance.get_num_tokens( | |||
| model=embedding_model_instance.model, | |||
| credentials=embedding_model_instance.credentials, | |||
| texts=[document.page_content] | |||
| ) | |||
| if doc_form and doc_form == 'qa_model': | |||
| text_generation_model = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id | |||
| model_instance = self.model_manager.get_default_model_instance( | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.LLM | |||
| ) | |||
| model_type_instance = model_instance.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
| if len(preview_texts) > 0: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], | |||
| doc_language) | |||
| document_qa_list = self.format_split_text(response) | |||
| price_info = model_type_instance.get_price( | |||
| model=model_instance.model, | |||
| credentials=model_instance.credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=total_segments * 2000, | |||
| ) | |||
| return { | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), | |||
| "currency": embedding_model.get_currency(), | |||
| "total_price": '{:f}'.format(price_info.total_amount), | |||
| "currency": price_info.currency, | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| } | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| embedding_price_info = embedding_model_type_instance.get_price( | |||
| model=embedding_model_instance.model, | |||
| credentials=embedding_model_instance.credentials, | |||
| price_type=PriceType.INPUT, | |||
| tokens=tokens | |||
| ) | |||
| return { | |||
| "total_segments": total_segments, | |||
| "tokens": tokens, | |||
| "total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0, | |||
| "currency": embedding_model.get_currency() if embedding_model else 'USD', | |||
| "total_price": '{:f}'.format(embedding_price_info.total_amount) if embedding_model_instance else 0, | |||
| "currency": embedding_price_info.currency if embedding_model_instance else 'USD', | |||
| "preview": preview_texts | |||
| } | |||
| @@ -656,25 +715,36 @@ class IndexingRunner: | |||
| """ | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | |||
| embedding_model = None | |||
| embedding_model_instance = None | |||
| if dataset.indexing_technique == 'high_quality': | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| embedding_model_instance = self.model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider_name=dataset.embedding_model_provider, | |||
| model_name=dataset.embedding_model | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model | |||
| ) | |||
| # chunk nodes by chunk size | |||
| indexing_start_at = time.perf_counter() | |||
| tokens = 0 | |||
| chunk_size = 100 | |||
| embedding_model_type_instance = None | |||
| if embedding_model_instance: | |||
| embedding_model_type_instance = embedding_model_instance.model_type_instance | |||
| embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance) | |||
| for i in range(0, len(documents), chunk_size): | |||
| # check document is paused | |||
| self._check_document_paused_status(dataset_document.id) | |||
| chunk_documents = documents[i:i + chunk_size] | |||
| if dataset.indexing_technique == 'high_quality' or embedding_model: | |||
| if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance: | |||
| tokens += sum( | |||
| embedding_model.get_num_tokens(document.page_content) | |||
| embedding_model_type_instance.get_num_tokens( | |||
| embedding_model_instance.model, | |||
| embedding_model_instance.credentials, | |||
| [document.page_content] | |||
| ) | |||
| for document in chunk_documents | |||
| ) | |||
| @@ -1,95 +0,0 @@ | |||
| from typing import Any, List, Dict | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.schema import get_buffer_string, BaseMessage | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message | |||
| class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): | |||
| conversation: Conversation | |||
| human_prefix: str = "Human" | |||
| ai_prefix: str = "Assistant" | |||
| model_instance: BaseLLM | |||
| memory_key: str = "chat_history" | |||
| max_token_limit: int = 2000 | |||
| message_limit: int = 10 | |||
| @property | |||
| def buffer(self) -> List[BaseMessage]: | |||
| """String buffer of memory.""" | |||
| app_model = self.conversation.app | |||
| # fetch limited messages desc, and return reversed | |||
| messages = db.session.query(Message).filter( | |||
| Message.conversation_id == self.conversation.id, | |||
| Message.answer != '' | |||
| ).order_by(Message.created_at.desc()).limit(self.message_limit).all() | |||
| messages = list(reversed(messages)) | |||
| message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id) | |||
| chat_messages: List[PromptMessage] = [] | |||
| for message in messages: | |||
| files = message.message_files | |||
| if files: | |||
| file_objs = message_file_parser.transform_message_files( | |||
| files, message.app_model_config | |||
| ) | |||
| prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs] | |||
| chat_messages.append(PromptMessage( | |||
| content=message.query, | |||
| type=MessageType.USER, | |||
| files=prompt_message_files | |||
| )) | |||
| else: | |||
| chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER)) | |||
| chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) | |||
| if not chat_messages: | |||
| return [] | |||
| # prune the chat message if it exceeds the max token limit | |||
| curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) | |||
| if curr_buffer_length > self.max_token_limit: | |||
| pruned_memory = [] | |||
| while curr_buffer_length > self.max_token_limit and chat_messages: | |||
| pruned_memory.append(chat_messages.pop(0)) | |||
| curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) | |||
| return to_lc_messages(chat_messages) | |||
| @property | |||
| def memory_variables(self) -> List[str]: | |||
| """Will always return list of memory variables. | |||
| :meta private: | |||
| """ | |||
| return [self.memory_key] | |||
| def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """Return history buffer.""" | |||
| buffer: Any = self.buffer | |||
| if self.return_messages: | |||
| final_buffer: Any = buffer | |||
| else: | |||
| final_buffer = get_buffer_string( | |||
| buffer, | |||
| human_prefix=self.human_prefix, | |||
| ai_prefix=self.ai_prefix, | |||
| ) | |||
| return {self.memory_key: final_buffer} | |||
| def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: | |||
| """Nothing should be saved or changed""" | |||
| pass | |||
| def clear(self) -> None: | |||
| """Nothing to clear, got a memory like a vault.""" | |||
| pass | |||
| @@ -1,36 +0,0 @@ | |||
| from typing import Any, List, Dict | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.schema import get_buffer_string | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| ReadOnlyConversationTokenDBBufferSharedMemory | |||
| class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory): | |||
| memory: ReadOnlyConversationTokenDBBufferSharedMemory | |||
| @property | |||
| def memory_variables(self) -> List[str]: | |||
| """Return memory variables.""" | |||
| return self.memory.memory_variables | |||
| def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: | |||
| """Load memory variables from memory.""" | |||
| buffer: Any = self.memory.buffer | |||
| final_buffer = get_buffer_string( | |||
| buffer, | |||
| human_prefix=self.memory.human_prefix, | |||
| ai_prefix=self.memory.ai_prefix, | |||
| ) | |||
| return {self.memory.memory_key: final_buffer} | |||
| def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: | |||
| """Nothing should be saved or changed""" | |||
| pass | |||
| def clear(self) -> None: | |||
| """Nothing to clear, got a memory like a vault.""" | |||
| pass | |||
| @@ -0,0 +1,109 @@ | |||
| from core.file.message_file_parser import MessageFileParser | |||
| from core.model_manager import ModelInstance | |||
| from core.model_runtime.entities.message_entities import PromptMessage, TextPromptMessageContent, UserPromptMessage, \ | |||
| AssistantPromptMessage, PromptMessageRole | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.model_providers import model_provider_factory | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message | |||
| class TokenBufferMemory: | |||
| def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: | |||
| self.conversation = conversation | |||
| self.model_instance = model_instance | |||
| def get_history_prompt_messages(self, max_token_limit: int = 2000, | |||
| message_limit: int = 10) -> list[PromptMessage]: | |||
| """ | |||
| Get history prompt messages. | |||
| :param max_token_limit: max token limit | |||
| :param message_limit: message limit | |||
| """ | |||
| app_record = self.conversation.app | |||
| # fetch limited messages, and return reversed | |||
| messages = db.session.query(Message).filter( | |||
| Message.conversation_id == self.conversation.id, | |||
| Message.answer != '' | |||
| ).order_by(Message.created_at.desc()).limit(message_limit).all() | |||
| messages = list(reversed(messages)) | |||
| message_file_parser = MessageFileParser( | |||
| tenant_id=app_record.tenant_id, | |||
| app_id=app_record.id | |||
| ) | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| files = message.message_files | |||
| if files: | |||
| file_objs = message_file_parser.transform_message_files( | |||
| files, message.app_model_config | |||
| ) | |||
| prompt_message_contents = [TextPromptMessageContent(data=message.query)] | |||
| for file_obj in file_objs: | |||
| prompt_message_contents.append(file_obj.prompt_message_content) | |||
| prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) | |||
| else: | |||
| prompt_messages.append(UserPromptMessage(content=message.query)) | |||
| prompt_messages.append(AssistantPromptMessage(content=message.answer)) | |||
| if not prompt_messages: | |||
| return [] | |||
| # prune the chat message if it exceeds the max token limit | |||
| provider_instance = model_provider_factory.get_provider_instance(self.model_instance.provider) | |||
| model_type_instance = provider_instance.get_model_instance(ModelType.LLM) | |||
| curr_message_tokens = model_type_instance.get_num_tokens( | |||
| self.model_instance.model, | |||
| self.model_instance.credentials, | |||
| prompt_messages | |||
| ) | |||
| if curr_message_tokens > max_token_limit: | |||
| pruned_memory = [] | |||
| while curr_message_tokens > max_token_limit and prompt_messages: | |||
| pruned_memory.append(prompt_messages.pop(0)) | |||
| curr_message_tokens = model_type_instance.get_num_tokens( | |||
| self.model_instance.model, | |||
| self.model_instance.credentials, | |||
| prompt_messages | |||
| ) | |||
| return prompt_messages | |||
| def get_history_prompt_text(self, human_prefix: str = "Human", | |||
| ai_prefix: str = "Assistant", | |||
| max_token_limit: int = 2000, | |||
| message_limit: int = 10) -> str: | |||
| """ | |||
| Get history prompt text. | |||
| :param human_prefix: human prefix | |||
| :param ai_prefix: ai prefix | |||
| :param max_token_limit: max token limit | |||
| :param message_limit: message limit | |||
| :return: | |||
| """ | |||
| prompt_messages = self.get_history_prompt_messages( | |||
| max_token_limit=max_token_limit, | |||
| message_limit=message_limit | |||
| ) | |||
| string_messages = [] | |||
| for m in prompt_messages: | |||
| if m.role == PromptMessageRole.USER: | |||
| role = human_prefix | |||
| elif m.role == PromptMessageRole.ASSISTANT: | |||
| role = ai_prefix | |||
| else: | |||
| continue | |||
| message = f"{role}: {m.content}" | |||
| string_messages.append(message) | |||
| return "\n".join(string_messages) | |||
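A minimal sketch of how `TokenBufferMemory` might be consumed once a conversation and a model instance are at hand (the module path and the conversation lookup are assumptions; the class itself is the one defined above):

```python
from core.memory.token_buffer_memory import TokenBufferMemory  # module path assumed
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.model import Conversation

conversation = db.session.query(Conversation).first()  # illustrative lookup

model_instance = ModelManager().get_default_model_instance(
    tenant_id=conversation.app.tenant_id,  # tenant resolution assumed
    model_type=ModelType.LLM
)

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# history as structured prompt messages, pruned to the token limit
history_messages = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=10)

# or as a single "Human: ... / Assistant: ..." transcript string
history_text = memory.get_history_prompt_text(human_prefix="Human", ai_prefix="Assistant")
```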
| @@ -0,0 +1,209 @@ | |||
| from typing import Optional, Union, Generator, cast, List, IO | |||
| from core.entities.provider_configuration import ProviderModelBundle | |||
| from core.errors.error import ProviderTokenNotInitError | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
| from core.model_runtime.entities.llm_entities import LLMResult | |||
| from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.model_runtime.entities.rerank_entities import RerankResult | |||
| from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.__base.moderation_model import ModerationModel | |||
| from core.model_runtime.model_providers.__base.rerank_model import RerankModel | |||
| from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.provider_manager import ProviderManager | |||
| class ModelInstance: | |||
| """ | |||
| Model instance class | |||
| """ | |||
| def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: | |||
| self._provider_model_bundle = provider_model_bundle | |||
| self.model = model | |||
| self.provider = provider_model_bundle.configuration.provider.provider | |||
| self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) | |||
| self.model_type_instance = self._provider_model_bundle.model_type_instance | |||
| def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict: | |||
| """ | |||
| Fetch credentials from provider model bundle | |||
| :param provider_model_bundle: provider model bundle | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| credentials = provider_model_bundle.configuration.get_current_credentials( | |||
| model_type=provider_model_bundle.model_type_instance.model_type, | |||
| model=model | |||
| ) | |||
| if credentials is None: | |||
| raise ProviderTokenNotInitError(f"Model {model} credentials are not initialized.") | |||
| return credentials | |||
| def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke large language model | |||
| :param prompt_messages: prompt messages | |||
| :param model_parameters: model parameters | |||
| :param tools: tools for tool calling | |||
| :param stop: stop words | |||
| :param stream: is stream response | |||
| :param user: unique user id | |||
| :param callbacks: callbacks | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| if not isinstance(self.model_type_instance, LargeLanguageModel): | |||
| raise Exception("Model type instance is not LargeLanguageModel") | |||
| self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) | |||
| return self.model_type_instance.invoke( | |||
| model=self.model, | |||
| credentials=self.credentials, | |||
| prompt_messages=prompt_messages, | |||
| model_parameters=model_parameters, | |||
| tools=tools, | |||
| stop=stop, | |||
| stream=stream, | |||
| user=user, | |||
| callbacks=callbacks | |||
| ) | |||
| def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \ | |||
| -> TextEmbeddingResult: | |||
| """ | |||
| Invoke text embedding model | |||
| :param texts: texts to embed | |||
| :param user: unique user id | |||
| :return: embeddings result | |||
| """ | |||
| if not isinstance(self.model_type_instance, TextEmbeddingModel): | |||
| raise Exception("Model type instance is not TextEmbeddingModel") | |||
| self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) | |||
| return self.model_type_instance.invoke( | |||
| model=self.model, | |||
| credentials=self.credentials, | |||
| texts=texts, | |||
| user=user | |||
| ) | |||
| def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, | |||
| user: Optional[str] = None) \ | |||
| -> RerankResult: | |||
| """ | |||
| Invoke rerank model | |||
| :param query: search query | |||
| :param docs: docs for reranking | |||
| :param score_threshold: score threshold | |||
| :param top_n: top n | |||
| :param user: unique user id | |||
| :return: rerank result | |||
| """ | |||
| if not isinstance(self.model_type_instance, RerankModel): | |||
| raise Exception("Model type instance is not RerankModel") | |||
| self.model_type_instance = cast(RerankModel, self.model_type_instance) | |||
| return self.model_type_instance.invoke( | |||
| model=self.model, | |||
| credentials=self.credentials, | |||
| query=query, | |||
| docs=docs, | |||
| score_threshold=score_threshold, | |||
| top_n=top_n, | |||
| user=user | |||
| ) | |||
| def invoke_moderation(self, text: str, user: Optional[str] = None) \ | |||
| -> bool: | |||
| """ | |||
| Invoke moderation model | |||
| :param text: text to moderate | |||
| :param user: unique user id | |||
| :return: false if text is safe, true otherwise | |||
| """ | |||
| if not isinstance(self.model_type_instance, ModerationModel): | |||
| raise Exception("Model type instance is not ModerationModel") | |||
| self.model_type_instance = cast(ModerationModel, self.model_type_instance) | |||
| return self.model_type_instance.invoke( | |||
| model=self.model, | |||
| credentials=self.credentials, | |||
| text=text, | |||
| user=user | |||
| ) | |||
| def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \ | |||
| -> str: | |||
| """ | |||
| Invoke speech-to-text model | |||
| :param file: audio file | |||
| :param user: unique user id | |||
| :return: text for given audio file | |||
| """ | |||
| if not isinstance(self.model_type_instance, Speech2TextModel): | |||
| raise Exception("Model type instance is not Speech2TextModel") | |||
| self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) | |||
| return self.model_type_instance.invoke( | |||
| model=self.model, | |||
| credentials=self.credentials, | |||
| file=file, | |||
| user=user | |||
| ) | |||
| class ModelManager: | |||
| def __init__(self) -> None: | |||
| self._provider_manager = ProviderManager() | |||
| def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: | |||
| """ | |||
| Get model instance | |||
| :param tenant_id: tenant id | |||
| :param provider: provider name | |||
| :param model_type: model type | |||
| :param model: model name | |||
| :return: | |||
| """ | |||
| provider_model_bundle = self._provider_manager.get_provider_model_bundle( | |||
| tenant_id=tenant_id, | |||
| provider=provider, | |||
| model_type=model_type | |||
| ) | |||
| return ModelInstance(provider_model_bundle, model) | |||
| def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance: | |||
| """ | |||
| Get default model instance | |||
| :param tenant_id: tenant id | |||
| :param model_type: model type | |||
| :return: | |||
| """ | |||
| default_model_entity = self._provider_manager.get_default_model( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type | |||
| ) | |||
| if not default_model_entity: | |||
| raise ProviderTokenNotInitError(f"Default model not found for {model_type}") | |||
| return self.get_model_instance( | |||
| tenant_id=tenant_id, | |||
| provider=default_model_entity.provider.provider, | |||
| model_type=model_type, | |||
| model=default_model_entity.model | |||
| ) | |||
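`ModelManager` becomes the single entry point the rest of this PR migrates to: callers either name a provider and model explicitly or fall back to the tenant's default for a model type, then invoke through the typed `invoke_*` helpers. A minimal usage sketch (tenant id, provider, and model names are placeholders):

```python
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType

model_manager = ModelManager()

# explicit provider/model lookup
llm = model_manager.get_model_instance(
    tenant_id="tenant-id",          # placeholder
    provider="openai",              # placeholder
    model_type=ModelType.LLM,
    model="gpt-3.5-turbo"           # placeholder
)

# or the tenant's configured default for a model type
default_llm = model_manager.get_default_model_instance(
    tenant_id="tenant-id",
    model_type=ModelType.LLM
)

# invoke_llm raises if the bundled model type is not an LLM
result = llm.invoke_llm(
    prompt_messages=[UserPromptMessage(content="Hello")],
    model_parameters={"temperature": 0.7},
    stream=False
)
```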
| @@ -1,335 +0,0 @@ | |||
| from typing import Optional | |||
| from langchain.callbacks.base import Callbacks | |||
| from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.moderation.base import BaseModeration | |||
| from core.model_providers.models.reranking.base import BaseReranking | |||
| from core.model_providers.models.speech2text.base import BaseSpeech2Text | |||
| from extensions.ext_database import db | |||
| from models.provider import TenantDefaultModel | |||
| class ModelFactory: | |||
| @classmethod | |||
| def get_text_generation_model_from_model_config(cls, tenant_id: str, | |||
| model_config: dict, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None) -> Optional[BaseLLM]: | |||
| provider_name = model_config.get("provider") | |||
| model_name = model_config.get("name") | |||
| completion_params = model_config.get("completion_params", {}) | |||
| return cls.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_provider_name=provider_name, | |||
| model_name=model_name, | |||
| model_kwargs=ModelKwargs( | |||
| temperature=completion_params.get('temperature', 0), | |||
| max_tokens=completion_params.get('max_tokens', 256), | |||
| top_p=completion_params.get('top_p', 0), | |||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | |||
| presence_penalty=completion_params.get('presence_penalty', 0.1) | |||
| ), | |||
| streaming=streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| @classmethod | |||
| def get_text_generation_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None, | |||
| model_kwargs: Optional[ModelKwargs] = None, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None, | |||
| deduct_quota: bool = True) -> Optional[BaseLLM]: | |||
| """ | |||
| get text generation model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :param model_kwargs: | |||
| :param streaming: | |||
| :param callbacks: | |||
| :param deduct_quota: | |||
| :return: | |||
| """ | |||
| is_default_model = False | |||
| if model_provider_name is None and model_name is None: | |||
| default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default System Reasoning Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| is_default_model = True | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init text generation model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION) | |||
| try: | |||
| model_instance = model_class( | |||
| model_provider=model_provider, | |||
| name=model_name, | |||
| model_kwargs=model_kwargs, | |||
| streaming=streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| except LLMBadRequestError as e: | |||
| if is_default_model: | |||
| raise LLMBadRequestError(f"Default model {model_name} is not available. " | |||
| f"Please check your model provider credentials.") | |||
| else: | |||
| raise e | |||
| if is_default_model or not deduct_quota: | |||
| model_instance.deduct_quota = False | |||
| return model_instance | |||
| @classmethod | |||
| def get_embedding_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None) -> Optional[BaseEmbedding]: | |||
| """ | |||
| get embedding model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| if model_provider_name is None and model_name is None: | |||
| default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default Embedding Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init embedding model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_reranking_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None) -> Optional[BaseReranking]: | |||
| """ | |||
| get reranking model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| if (model_provider_name is None or len(model_provider_name) == 0) and (model_name is None or len(model_name) == 0): | |||
| default_model = cls.get_default_model(tenant_id, ModelType.RERANKING) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default Reranking Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init reranking model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.RERANKING) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_speech2text_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]: | |||
| """ | |||
| get speech to text model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| if model_provider_name is None and model_name is None: | |||
| default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default Speech-to-Text Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init speech to text model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_moderation_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: str, | |||
| model_name: str) -> Optional[BaseModeration]: | |||
| """ | |||
| get moderation model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init moderation model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.MODERATION) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel: | |||
| """ | |||
| get default model of model type. | |||
| :param tenant_id: | |||
| :param model_type: | |||
| :return: | |||
| """ | |||
| # get default model | |||
| default_model = db.session.query(TenantDefaultModel) \ | |||
| .filter( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.value | |||
| ).first() | |||
| if not default_model: | |||
| model_provider_rules = ModelProviderFactory.get_provider_rules() | |||
| for model_provider_name, model_provider_rule in model_provider_rules.items(): | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| continue | |||
| model_list = model_provider.get_supported_model_list(model_type) | |||
| if model_list: | |||
| model_info = model_list[0] | |||
| default_model = TenantDefaultModel( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type.value, | |||
| provider_name=model_provider_name, | |||
| model_name=model_info['id'] | |||
| ) | |||
| db.session.add(default_model) | |||
| db.session.commit() | |||
| break | |||
| return default_model | |||
| @classmethod | |||
| def update_default_model(cls, | |||
| tenant_id: str, | |||
| model_type: ModelType, | |||
| provider_name: str, | |||
| model_name: str) -> TenantDefaultModel: | |||
| """ | |||
| update default model of model type. | |||
| :param tenant_id: | |||
| :param model_type: | |||
| :param provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| model_provider_name = ModelProviderFactory.get_provider_names() | |||
| if provider_name not in model_provider_name: | |||
| raise ValueError(f'Invalid provider name: {provider_name}') | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| model_list = model_provider.get_supported_model_list(model_type) | |||
| model_ids = [model['id'] for model in model_list] | |||
| if model_name not in model_ids: | |||
| raise ValueError(f'Invalid model name: {model_name}') | |||
| # get default model | |||
| default_model = db.session.query(TenantDefaultModel) \ | |||
| .filter( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.value | |||
| ).first() | |||
| if default_model: | |||
| # update default model | |||
| default_model.provider_name = provider_name | |||
| default_model.model_name = model_name | |||
| db.session.commit() | |||
| else: | |||
| # create default model | |||
| default_model = TenantDefaultModel( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type.value, | |||
| provider_name=provider_name, | |||
| model_name=model_name, | |||
| ) | |||
| db.session.add(default_model) | |||
| db.session.commit() | |||
| return default_model | |||