 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
+                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
     for root, _, files in os.walk("."):
         for file in files:

 NOTION_CLIENT_SECRET=you-client-secret
 NOTION_CLIENT_ID=you-client-id
 NOTION_INTERNAL_SECRET=you-internal-secret

+# Hosted Model Credentials
+HOSTED_OPENAI_ENABLED=false
+HOSTED_OPENAI_API_KEY=
+HOSTED_OPENAI_API_BASE=
+HOSTED_OPENAI_API_ORGANIZATION=
+HOSTED_OPENAI_QUOTA_LIMIT=200
+HOSTED_OPENAI_PAID_ENABLED=false
+HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
+HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
+HOSTED_AZURE_OPENAI_ENABLED=false
+HOSTED_AZURE_OPENAI_API_KEY=
+HOSTED_AZURE_OPENAI_API_BASE=
+HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
+HOSTED_ANTHROPIC_ENABLED=false
+HOSTED_ANTHROPIC_API_BASE=
+HOSTED_ANTHROPIC_API_KEY=
+HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_PAID_ENABLED=false
+HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+STRIPE_API_KEY=
+STRIPE_WEBHOOK_SECRET=
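
These hosted-provider settings are plain dotenv strings: the feature flags hold the literal text false/true, and the secrets default to empty. The application reads them through small helpers (the get_env / get_bool_env calls visible in the config changes further down). A minimal sketch of how such helpers typically behave; the implementations here are assumptions, not the project's exact code:

import os

def get_env(key: str, default: str = '') -> str:
    # Raw string lookup with a fallback default.
    return os.environ.get(key, default)

def get_bool_env(key: str) -> bool:
    # Only the literal string 'true' (any casing) counts as enabled.
    return get_env(key, 'false').lower() == 'true'

hosted_openai_enabled = get_bool_env('HOSTED_OPENAI_ENABLED')
hosted_openai_quota = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT', '200'))
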
 import flask_login
 from flask_cors import CORS

+from core.model_providers.providers import hosted
 from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail
+    ext_database, ext_storage, ext_mail, ext_stripe
 from extensions.ext_database import db
 from extensions.ext_login import login_manager

     register_blueprints(app)
     register_commands(app)
-    core.init_app(app)
+    hosted.init_app(app)
     return app

     ext_login.init_app(app)
     ext_mail.init_app(app)
     ext_sentry.init_app(app)
+    ext_stripe.init_app(app)

 def _create_tenant_for_account(account):

     }

+@app.route('/db-pool-stat')
+def pool_stat():
+    engine = db.engine
+    return {
+        'pool_size': engine.pool.size(),
+        'checked_in_connections': engine.pool.checkedin(),
+        'checked_out_connections': engine.pool.checkedout(),
+        'overflow_connections': engine.pool.overflow(),
+        'connection_timeout': engine.pool.timeout(),
+        'recycle_time': db.engine.pool._recycle
+    }

 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5001)
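
The new /db-pool-stat route surfaces SQLAlchemy's QueuePool counters, which is handy when tuning pool_size and pool_recycle. A quick sanity check once the server is up on port 5001, as in the app.run() call above (the response values in the comment are illustrative):

import requests

stats = requests.get('http://localhost:5001/db-pool-stat').json()
# e.g. {'pool_size': 30, 'checked_in_connections': 2, 'checked_out_connections': 1, ...}
print(stats['checked_out_connections'], 'connections currently in use')
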
 import datetime
+import logging
+import math
 import random
 import string
 import time

 from werkzeug.exceptions import NotFound

 from core.index.index import IndexBuilder
+from core.model_providers.providers.hosted import hosted_model_providers
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, DatasetQuery, Document
 from models.model import Account
 import secrets
 import base64
-from models.provider import Provider, ProviderName
-from services.provider_service import ProviderService
+from models.provider import Provider, ProviderType, ProviderQuotaType

 @click.command('reset-password', help='Reset the account password.')

 @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
 def sync_anthropic_hosted_providers():
+    if not hosted_model_providers.anthropic:
+        click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
+        return
+
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0
     page = 1
     while True:
         try:
-            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+            providers = db.session.query(Provider).filter(
+                Provider.provider_name == 'anthropic',
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == ProviderQuotaType.TRIAL.value,
+            ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
         except NotFound:
             break

         page += 1
-        for tenant in tenants:
+        for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
-                ProviderService.create_system_provider(
-                    tenant,
-                    ProviderName.ANTHROPIC.value,
-                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
-                    True
-                )
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                original_quota_limit = provider.quota_limit
+                new_quota_limit = hosted_model_providers.anthropic.quota_limit
+                division = math.ceil(new_quota_limit / 1000)
+                provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
+                    else original_quota_limit * division
+                provider.quota_used = division * provider.quota_used
+                db.session.commit()
                 count += 1
             except Exception as e:
                 click.echo(click.style(
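
The quota rescaling above is easiest to check with concrete numbers. With the new hosted Anthropic limit of 1000000 (see HOSTED_ANTHROPIC_QUOTA_LIMIT), division = ceil(1000000 / 1000) = 1000; tenants still on the old default limit of 1000 jump straight to the new limit, while custom limits and already-used quota are both scaled by the same factor. A worked example mirroring that logic (values illustrative):

import math

new_quota_limit = 1000000                      # HOSTED_ANTHROPIC_QUOTA_LIMIT
division = math.ceil(new_quota_limit / 1000)   # 1000

def rescale(original_quota_limit: int, quota_used: int) -> tuple:
    # Same branch as sync_anthropic_hosted_providers.
    quota_limit = new_quota_limit if original_quota_limit == 1000 \
        else original_quota_limit * division
    return quota_limit, quota_used * division

print(rescale(1000, 2))      # (1000000, 2000): old default limit, usage scaled
print(rescale(5000, 1200))   # (5000000, 1200000): custom limit, both scaled
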
     'SESSION_USE_SIGNER': 'True',
     'DEPLOY_ENV': 'PRODUCTION',
     'SQLALCHEMY_POOL_SIZE': 30,
+    'SQLALCHEMY_POOL_RECYCLE': 3600,
     'SQLALCHEMY_ECHO': 'False',
     'SENTRY_TRACES_SAMPLE_RATE': 1.0,
     'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
-    'DEFAULT_LLM_PROVIDER': 'openai',
-    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
-    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
+    'HOSTED_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_OPENAI_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
+    'HOSTED_AZURE_OPENAI_ENABLED': 'False',
+    'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
+    'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
+    'HOSTED_ANTHROPIC_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
+    'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
     'TENANT_DOCUMENT_COUNT': 100,
     'CLEAN_DAY_SETTING': 30
 }

         self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
+        self.SQLALCHEMY_ENGINE_OPTIONS = {
+            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
+            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+        }
         self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')

         self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')

         # hosted provider credentials
-        self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
-        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
-        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
-        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
+        self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
+        self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
+        self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
+        self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))
+        self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
+        self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
+        self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
+        self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
+        self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
+        self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
+        self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
+        self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
+        self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
+        self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')
+
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

-        # For temp use only
-        # set default LLM provider, default is 'openai', support `azure_openai`
-        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
-
         # notion import setting
         self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
         self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')

 from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

 # Import workspace controllers
-from .workspace import workspace, members, model_providers, account, tool_providers
+from .workspace import workspace, members, providers, model_providers, account, tool_providers, models

 # Import explore controllers
 from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

 # Import universal chat controllers
 from .universal_chat import chat, conversation, message, parameter, audio

+# Import webhook controllers
+from .webhook import stripe

 import json
 from datetime import datetime

-import flask
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
-from werkzeug.exceptions import Unauthorized, Forbidden
+from werkzeug.exceptions import Forbidden

 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError
+from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db

         if args['model_config'] is not None:
             # validate config
             model_configuration = AppModelConfigService.validate_configuration(
+                tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config'],
-                mode=args['mode']
+                config=args['model_config']
             )

             app = App(

             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])

+            default_model = ModelFactory.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.TEXT_GENERATION
+            )
+
+            if default_model:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = default_model.provider_name
+                model_dict['name'] = default_model.model_name
+                app_model_config.model = json.dumps(model_dict)
+            else:
+                raise ProviderNotInitializeError(
+                    f"No Text Generation Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")

         app.name = args['name']
         app.mode = args['mode']
         app.icon = args['icon']

     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService

 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse

         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()

+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user

         try:
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )

         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()

+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user

         try:
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )

 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
     LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError

     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination

         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
+            tenant_id=current_user.current_tenant_id,
             account=current_user,
-            config=request.json,
-            mode=app_model.mode
+            config=request.json
         )

         new_app_model_config = AppModelConfig(

         # validate args
         DocumentService.estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
         return response, 200

 from werkzeug.exceptions import NotFound, Forbidden

 import services
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document

         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()

+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,

                 raise NotFound("File not found.")

             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif args['info_list']['data_source_type'] == 'notion_import':
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
-                                                                args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    args['info_list']['notion_info_list'],
+                                                                    args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response, 200

 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from libs.helper import TimestampField
 from extensions.ext_database import db

         # validate args
         DocumentService.document_create_args_validate(args)

+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:

         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         args = parser.parse_args()

+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         # validate args
         DocumentService.document_create_args_validate(args)

         indexing_runner = IndexingRunner()
-        response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
+
+        try:
+            response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
+                                                              data_process_rule_dict)
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")

         return response

                 raise NotFound("File not found.")

             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif dataset.data_source_type:
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(info_list,
-                                                                data_process_rule_dict)
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    info_list,
+                                                                    data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')

         return response

 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService

             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except ValueError as e:
+            raise ValueError(str(e))
         except Exception as e:
             logging.exception("Hit testing failed.")
             raise InternalServerError(str(e))

     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService

 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import InstalledApp

         """Retrieve app parameters."""
         app_model = installed_app.app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }

     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \

 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.constant import llm_constant
 from core.conversation_message_task import PubHandler
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService

         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
         args = parser.parse_args()

         # update app model config
         args['model_config'] = app_model_config.to_dict()
         args['model_config']['model']['name'] = args['model']
-
-        if not llm_constant.models[args['model']]:
-            raise ValueError("Model not exists.")
-
-        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
+        args['model_config']['model']['provider'] = args['provider']
         args['model_config']['agent_mode']['tools'] = args['tools']

         if not args['model_config']['agent_mode']['tools']:

     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError

 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App

         """Retrieve app parameters."""
         app_model = universal_app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')

         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
         }

+import logging
+
+import stripe
+from flask import request, current_app
+from flask_restful import Resource
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import only_edition_cloud
+from services.provider_checkout_service import ProviderCheckoutService
+
+
+class StripeWebhookApi(Resource):
+    @setup_required
+    @only_edition_cloud
+    def post(self):
+        payload = request.data
+        sig_header = request.headers.get('STRIPE_SIGNATURE')
+        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
+
+        try:
+            event = stripe.Webhook.construct_event(
+                payload, sig_header, webhook_secret
+            )
+        except ValueError as e:
+            # Invalid payload
+            return 'Invalid payload', 400
+        except stripe.error.SignatureVerificationError as e:
+            # Invalid signature
+            return 'Invalid signature', 400
+
+        # Handle the checkout.session.completed event
+        if event['type'] == 'checkout.session.completed':
+            logging.debug(event['data']['object']['id'])
+            logging.debug(event['data']['object']['amount_subtotal'])
+            logging.debug(event['data']['object']['currency'])
+            logging.debug(event['data']['object']['payment_intent'])
+            logging.debug(event['data']['object']['payment_status'])
+            logging.debug(event['data']['object']['metadata'])
+
+            # Fulfill the purchase...
+            provider_checkout_service = ProviderCheckoutService()
+
+            try:
+                provider_checkout_service.fulfill_provider_order(event)
+            except Exception as e:
+                logging.debug(str(e))
+                return 'success', 200
+
+        return 'success', 200
+
+
+api.add_resource(StripeWebhookApi, '/webhook/stripe')
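
To exercise this handler locally you need a request whose Stripe-Signature header matches the payload: stripe.Webhook.construct_event() verifies an HMAC-SHA256 over "{timestamp}.{payload}" as documented by Stripe. A hypothetical test helper for building such a header (the secret and payload would be illustrative test values, not real credentials):

import hashlib
import hmac
import time

def sign_payload(payload: bytes, webhook_secret: str) -> str:
    # Build a header in Stripe's documented 't=<ts>,v1=<hex hmac-sha256>' form.
    timestamp = int(time.time())
    signed = f'{timestamp}.'.encode() + payload
    digest = hmac.new(webhook_secret.encode(), signed, hashlib.sha256).hexdigest()
    return f't={timestamp},v1={digest}'

# POST the raw payload to /webhook/stripe with headers
# {'Stripe-Signature': sign_payload(payload, webhook_secret)};
# construct_event() above should then accept it if the secrets match.
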
-# -*- coding:utf-8 -*-
-import base64
-import json
-import logging
-
-from flask import current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, abort
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden

 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.provider.errors import ValidateFailedError
-from extensions.ext_database import db
-from libs import rsa
-from models.provider import Provider, ProviderType, ProviderName
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from services.provider_checkout_service import ProviderCheckoutService
 from services.provider_service import ProviderService


-class ProviderListApi(Resource):
+class ModelProviderListApi(Resource):

     @setup_required
     @login_required
     def get(self):
         tenant_id = current_user.current_tenant_id

-        """
-        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
-        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
-        rest is replaced by * and the last two bits are displayed in plaintext
-        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
-        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
-        """
-        ProviderService.init_supported_provider(current_user.current_tenant)
-        providers = Provider.query.filter_by(tenant_id=tenant_id).all()
-
-        provider_list = [
-            {
-                'provider_name': p.provider_name,
-                'provider_type': p.provider_type,
-                'is_valid': p.is_valid,
-                'last_used': p.last_used,
-                'is_enabled': p.is_enabled,
-                **({
-                    'quota_type': p.quota_type,
-                    'quota_limit': p.quota_limit,
-                    'quota_used': p.quota_used
-                } if p.provider_type == ProviderType.SYSTEM.value else {}),
-                'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name), only_custom=True)
-                if p.provider_type == ProviderType.CUSTOM.value else None
-            }
-            for p in providers
-        ]
+        provider_service = ProviderService()
+        provider_list = provider_service.get_provider_list(tenant_id)

         return provider_list


-class ProviderTokenApi(Resource):
+class ModelProviderValidateApi(Resource):

     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_config_validate(
+                provider_name=provider_name,
+                config=args['config']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderUpdateApi(Resource):
+
+    # The role of the current user in the ta table must be admin or owner
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            logging.log(logging.ERROR,
-                        f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
             raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        if args['token']:
-            try:
-                ProviderService.validate_provider_configs(
-                    tenant=current_user.current_tenant,
-                    provider_name=ProviderName(provider),
-                    configs=args['token']
-                )
-                token_is_valid = True
-            except ValidateFailedError as ex:
-                raise ValueError(str(ex))
-
-            base64_encrypted_token = ProviderService.get_encrypted_token(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+        provider_service = ProviderService()
+
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                config=args['config']
             )
-        else:
-            base64_encrypted_token = None
-            token_is_valid = False
-
-        tenant = current_user.current_tenant
-
-        provider_model = db.session.query(Provider).filter(
-            Provider.tenant_id == tenant.id,
-            Provider.provider_name == provider,
-            Provider.provider_type == ProviderType.CUSTOM.value
-        ).first()
-
-        # Only allow updating token for CUSTOM provider type
-        if provider_model:
-            provider_model.encrypted_config = base64_encrypted_token
-            provider_model.is_valid = token_is_valid
-        else:
-            provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
-                                      provider_type=ProviderType.CUSTOM.value,
-                                      encrypted_config=base64_encrypted_token,
-                                      is_valid=token_is_valid)
-            db.session.add(provider_model)
-
-        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
-            other_providers = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
-                Provider.provider_name != provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-            for other_provider in other_providers:
-                other_provider.is_valid = False
-
-        db.session.commit()
-
-        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))

         return {'result': 'success'}, 201

+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name
+        )
+
+        return {'result': 'success'}, 204


-class ProviderTokenValidateApi(Resource):
+class ModelProviderModelValidateApi(Resource):

     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):

         parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()

-        # todo: remove this when the provider is supported
-        if provider in [ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
+        provider_service = ProviderService()

         result = True
         error = None

         try:
-            ProviderService.validate_provider_configs(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+            provider_service.custom_provider_model_config_validate(
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
             )
-        except ValidateFailedError as e:
+        except CredentialsValidateFailedError as ex:
             result = False
-            error = str(e)
+            error = str(ex)

         response = {'result': 'success' if result else 'error'}

         return response


-class ProviderSystemApi(Resource):
+class ModelProviderModelUpdateApi(Resource):

     @setup_required
     @login_required
     @account_initialization_required
-    def put(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
+    def post(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()

         parser = reqparse.RequestParser()
-        parser.add_argument('is_enabled', type=bool, required=True, location='json')
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| tenant = current_user.current_tenant_id | |||||
| provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first() | |||||
| if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value: | |||||
| provider_model.is_valid = args['is_enabled'] | |||||
| db.session.commit() | |||||
| elif not provider_model: | |||||
| if provider == ProviderName.OPENAI.value: | |||||
| quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'] | |||||
| elif provider == ProviderName.ANTHROPIC.value: | |||||
| quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'] | |||||
| else: | |||||
| quota_limit = 0 | |||||
| ProviderService.create_system_provider( | |||||
| tenant, | |||||
| provider, | |||||
| quota_limit, | |||||
| args['is_enabled'] | |||||
| provider_service = ProviderService() | |||||
| try: | |||||
| provider_service.add_or_save_custom_provider_model_config( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider_name=provider_name, | |||||
| model_name=args['model_name'], | |||||
| model_type=args['model_type'], | |||||
| config=args['config'] | |||||
| ) | ) | ||||
| else: | |||||
| abort(403) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| raise ValueError(str(ex)) | |||||
| return {'result': 'success'} | |||||
| return {'result': 'success'}, 200 | |||||
| @setup_required | @setup_required | ||||
| @login_required | @login_required | ||||
| @account_initialization_required | @account_initialization_required | ||||
| def get(self, provider): | |||||
| if provider not in [p.value for p in ProviderName]: | |||||
| abort(404) | |||||
| def delete(self, provider_name: str): | |||||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||||
| choices=['text-generation', 'embeddings', 'speech2text'], location='args') | |||||
| args = parser.parse_args() | |||||
| # The role of the current user in the tenant-account join table must be admin or owner | |||||
| provider_service = ProviderService() | |||||
| provider_service.delete_custom_provider_model( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider_name=provider_name, | |||||
| model_name=args['model_name'], | |||||
| model_type=args['model_type'] | |||||
| ) | |||||
| return {'result': 'success'}, 204 | |||||
| class PreferredProviderTypeUpdateApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider_name: str): | |||||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | if current_user.current_tenant.current_role not in ['admin', 'owner']: | ||||
| raise Forbidden() | raise Forbidden() | ||||
| provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id, | |||||
| Provider.provider_name == provider, | |||||
| Provider.provider_type == ProviderType.SYSTEM.value).first() | |||||
| system_model = None | |||||
| if provider_model: | |||||
| system_model = { | |||||
| 'result': 'success', | |||||
| 'provider': { | |||||
| 'provider_name': provider_model.provider_name, | |||||
| 'provider_type': provider_model.provider_type, | |||||
| 'is_valid': provider_model.is_valid, | |||||
| 'last_used': provider_model.last_used, | |||||
| 'is_enabled': provider_model.is_enabled, | |||||
| 'quota_type': provider_model.quota_type, | |||||
| 'quota_limit': provider_model.quota_limit, | |||||
| 'quota_used': provider_model.quota_used | |||||
| } | |||||
| } | |||||
| else: | |||||
| abort(404) | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, | |||||
| choices=['system', 'custom'], location='json') | |||||
| args = parser.parse_args() | |||||
| provider_service = ProviderService() | |||||
| provider_service.switch_preferred_provider( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider_name=provider_name, | |||||
| preferred_provider_type=args['preferred_provider_type'] | |||||
| ) | |||||
| return {'result': 'success'} | |||||
| class ModelProviderModelParameterRuleApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, provider_name: str): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||||
| args = parser.parse_args() | |||||
| return system_model | |||||
| provider_service = ProviderService() | |||||
| try: | |||||
| parameter_rules = provider_service.get_model_parameter_rules( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| model_provider_name=provider_name, | |||||
| model_name=args['model_name'], | |||||
| model_type='text-generation' | |||||
| ) | |||||
| except LLMBadRequestError: | |||||
| raise ProviderNotInitializeError( | |||||
| f"Current Text Generation Model is invalid. Please switch to the available model.") | |||||
| rules = { | |||||
| k: { | |||||
| 'enabled': v.enabled, | |||||
| 'min': v.min, | |||||
| 'max': v.max, | |||||
| 'default': v.default | |||||
| } | |||||
| for k, v in vars(parameter_rules).items() | |||||
| } | |||||
| api.add_resource(ProviderTokenApi, '/providers/<provider>/token', | |||||
| endpoint='current_providers_token') # Deprecated | |||||
| api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate', | |||||
| endpoint='current_providers_token_validate') # Deprecated | |||||
| return rules | |||||
| api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token', | |||||
| endpoint='workspaces_current_providers_token') # PUT for updating provider token | |||||
| api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate', | |||||
| endpoint='workspaces_current_providers_token_validate') # POST for validating provider token | |||||
| api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list | |||||
| api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system', | |||||
| endpoint='workspaces_current_providers_system') # GET for getting provider quota, PUT for updating provider status | |||||
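For reference, the rules mapping returned by ModelProviderModelParameterRuleApi above serializes to plain JSON keyed by parameter name. A plausible response shape; the parameter names and bounds here are assumptions, not values defined in this diff:

# Assumed shape of GET .../models/parameter-rules?model_name=...
rules = {
    'temperature': {'enabled': True, 'min': 0.0, 'max': 2.0, 'default': 1.0},
    'max_tokens': {'enabled': True, 'min': 1, 'max': 4096, 'default': 16},
}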
| class ModelProviderPaymentCheckoutUrlApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, provider_name: str): | |||||
| provider_service = ProviderCheckoutService() | |||||
| provider_checkout = provider_service.create_checkout( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider_name=provider_name, | |||||
| account=current_user | |||||
| ) | |||||
| return { | |||||
| 'url': provider_checkout.get_checkout_url() | |||||
| } | |||||
| api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') | |||||
| api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate') | |||||
| api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>') | |||||
| api.add_resource(ModelProviderModelValidateApi, | |||||
| '/workspaces/current/model-providers/<string:provider_name>/models/validate') | |||||
| api.add_resource(ModelProviderModelUpdateApi, | |||||
| '/workspaces/current/model-providers/<string:provider_name>/models') | |||||
| api.add_resource(PreferredProviderTypeUpdateApi, | |||||
| '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type') | |||||
| api.add_resource(ModelProviderModelParameterRuleApi, | |||||
| '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules') | |||||
| api.add_resource(ModelProviderPaymentCheckoutUrlApi, | |||||
| '/workspaces/current/model-providers/<string:provider_name>/checkout-url') |
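Taken together, these registrations form the new model-provider console surface. A hedged smoke test against a local deployment; the /console/api mount point and the session-cookie auth are assumptions, not shown in this diff:

import requests

BASE = 'http://localhost:5001/console/api'  # assumed mount point
COOKIES = {'session': '<console session cookie>'}  # assumption: Flask session auth

resp = requests.post(
    f'{BASE}/workspaces/current/model-providers/openai/models/validate',
    cookies=COOKIES,
    json={
        'model_name': 'gpt-3.5-turbo',           # hypothetical model
        'model_type': 'text-generation',
        'config': {'openai_api_key': 'sk-...'},  # assumed credential field
    },
)
print(resp.json())  # {'result': 'success'} or {'result': 'error'}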
| from flask_login import login_required, current_user | |||||
| from flask_restful import Resource, reqparse | |||||
| from controllers.console import api | |||||
| from controllers.console.setup import setup_required | |||||
| from controllers.console.wraps import account_initialization_required | |||||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||||
| from core.model_providers.models.entity.model_params import ModelType | |||||
| from models.provider import ProviderType | |||||
| from services.provider_service import ProviderService | |||||
| class DefaultModelApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||||
| choices=['text-generation', 'embeddings', 'speech2text'], location='args') | |||||
| args = parser.parse_args() | |||||
| tenant_id = current_user.current_tenant_id | |||||
| provider_service = ProviderService() | |||||
| default_model = provider_service.get_default_model_of_model_type( | |||||
| tenant_id=tenant_id, | |||||
| model_type=args['model_type'] | |||||
| ) | |||||
| if not default_model: | |||||
| return None | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider( | |||||
| tenant_id, | |||||
| default_model.provider_name | |||||
| ) | |||||
| if not model_provider: | |||||
| return { | |||||
| 'model_name': default_model.model_name, | |||||
| 'model_type': default_model.model_type, | |||||
| 'model_provider': { | |||||
| 'provider_name': default_model.provider_name | |||||
| } | |||||
| } | |||||
| provider = model_provider.provider | |||||
| rst = { | |||||
| 'model_name': default_model.model_name, | |||||
| 'model_type': default_model.model_type, | |||||
| 'model_provider': { | |||||
| 'provider_name': provider.provider_name, | |||||
| 'provider_type': provider.provider_type | |||||
| } | |||||
| } | |||||
| model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name) | |||||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||||
| rst['model_provider']['quota_type'] = provider.quota_type | |||||
| rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit'] | |||||
| rst['model_provider']['quota_limit'] = provider.quota_limit | |||||
| rst['model_provider']['quota_used'] = provider.quota_used | |||||
| return rst | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||||
| parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json') | |||||
| args = parser.parse_args() | |||||
| provider_service = ProviderService() | |||||
| provider_service.update_default_model_of_model_type( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| model_type=args['model_type'], | |||||
| provider_name=args['provider_name'], | |||||
| model_name=args['model_name'] | |||||
| ) | |||||
| return {'result': 'success'} | |||||
| class ValidModelApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, model_type): | |||||
| ModelType.value_of(model_type) | |||||
| provider_service = ProviderService() | |||||
| valid_models = provider_service.get_valid_model_list( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| model_type=model_type | |||||
| ) | |||||
| return valid_models | |||||
| api.add_resource(DefaultModelApi, '/workspaces/current/default-model') | |||||
| api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>') |
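Reconstructing the GET response shape from the handler above (the concrete values are placeholders): the quota fields are only attached when the preferred provider is the SYSTEM type.

# Illustrative response of GET /workspaces/current/default-model?model_type=text-generation
example_response = {
    'model_name': 'gpt-3.5-turbo',   # placeholder
    'model_type': 'text-generation',
    'model_provider': {
        'provider_name': 'openai',
        'provider_type': 'system',
        'quota_type': 'trial',       # placeholder; quota fields appear
        'quota_unit': 'times',       # only for SYSTEM providers
        'quota_limit': 200,
        'quota_used': 3
    }
}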
| # -*- coding:utf-8 -*- | |||||
| from flask_login import login_required, current_user | |||||
| from flask_restful import Resource, reqparse | |||||
| from werkzeug.exceptions import Forbidden | |||||
| from controllers.console import api | |||||
| from controllers.console.setup import setup_required | |||||
| from controllers.console.wraps import account_initialization_required | |||||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||||
| from models.provider import ProviderType | |||||
| from services.provider_service import ProviderService | |||||
| class ProviderListApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self): | |||||
| tenant_id = current_user.current_tenant_id | |||||
| """ | |||||
| If the provider type is AZURE_OPENAI, decode and return the four fields azure_api_type, azure_api_version, | |||||
| azure_api_base and azure_api_key as an object, where azure_api_key shows the first 6 characters in plaintext, | |||||
| the last 2 characters in plaintext, and the rest replaced by *. | |||||
| For any other provider type, decode and return the token field directly; it likewise shows the first 6 | |||||
| characters in plaintext, the rest replaced by * and the last 2 characters in plaintext. | |||||
| """ | |||||
| provider_service = ProviderService() | |||||
| provider_info_list = provider_service.get_provider_list(tenant_id) | |||||
| provider_list = [ | |||||
| { | |||||
| 'provider_name': p['provider_name'], | |||||
| 'provider_type': p['provider_type'], | |||||
| 'is_valid': p['is_valid'], | |||||
| 'last_used': p['last_used'], | |||||
| 'is_enabled': p['is_valid'], | |||||
| **({ | |||||
| 'quota_type': p['quota_type'], | |||||
| 'quota_limit': p['quota_limit'], | |||||
| 'quota_used': p['quota_used'] | |||||
| } if p['provider_type'] == ProviderType.SYSTEM.value else {}), | |||||
| 'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key']) | |||||
| if p['config'] else None | |||||
| } | |||||
| for name, provider_info in provider_info_list.items() | |||||
| for p in provider_info['providers'] | |||||
| ] | |||||
| return provider_list | |||||
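A minimal sketch of the masking rule described in the comment block above; obfuscate_token is a hypothetical helper, not a function introduced by this diff:

def obfuscate_token(token: str) -> str:
    # First 6 and last 2 characters stay in plaintext; the middle is masked.
    if len(token) <= 8:
        return token  # assumption: tokens this short are returned unmasked
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

print(obfuscate_token('sk-abcdefghijklmnop'))  # sk-abc***********op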
| class ProviderTokenApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider): | |||||
| # The role of the current user in the tenant-account join table must be admin or owner | |||||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||||
| raise Forbidden() | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('token', required=True, nullable=False, location='json') | |||||
| args = parser.parse_args() | |||||
| if provider == 'openai': | |||||
| args['token'] = { | |||||
| 'openai_api_key': args['token'] | |||||
| } | |||||
| provider_service = ProviderService() | |||||
| try: | |||||
| provider_service.save_custom_provider_config( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| provider_name=provider, | |||||
| config=args['token'] | |||||
| ) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| raise ValueError(str(ex)) | |||||
| return {'result': 'success'}, 201 | |||||
| class ProviderTokenValidateApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def post(self, provider): | |||||
| parser = reqparse.RequestParser() | |||||
| parser.add_argument('token', required=True, nullable=False, location='json') | |||||
| args = parser.parse_args() | |||||
| provider_service = ProviderService() | |||||
| if provider == 'openai': | |||||
| args['token'] = { | |||||
| 'openai_api_key': args['token'] | |||||
| } | |||||
| result = True | |||||
| error = None | |||||
| try: | |||||
| provider_service.custom_provider_config_validate( | |||||
| provider_name=provider, | |||||
| config=args['token'] | |||||
| ) | |||||
| except CredentialsValidateFailedError as ex: | |||||
| result = False | |||||
| error = str(ex) | |||||
| response = {'result': 'success' if result else 'error'} | |||||
| if not result: | |||||
| response['error'] = error | |||||
| return response | |||||
| api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token', | |||||
| endpoint='workspaces_current_providers_token') # PUT for updating provider token | |||||
| api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate', | |||||
| endpoint='workspaces_current_providers_token_validate') # POST for validating provider token | |||||
| api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list |
| 'created_at': TimestampField, | 'created_at': TimestampField, | ||||
| 'role': fields.String, | 'role': fields.String, | ||||
| 'providers': fields.List(fields.Nested(provider_fields)), | 'providers': fields.List(fields.Nested(provider_fields)), | ||||
| 'in_trail': fields.Boolean, | |||||
| 'in_trial': fields.Boolean, | |||||
| 'trial_end_reason': fields.String, | 'trial_end_reason': fields.String, | ||||
| } | } | ||||
| from controllers.service_api import api | from controllers.service_api import api | ||||
| from controllers.service_api.wraps import AppApiResource | from controllers.service_api.wraps import AppApiResource | ||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from models.provider import ProviderName | |||||
| from models.model import App | from models.model import App | ||||
| def get(self, app_model: App, end_user): | def get(self, app_model: App, end_user): | ||||
| """Retrieve app parameters.""" | """Retrieve app parameters.""" | ||||
| app_model_config = app_model.app_model_config | app_model_config = app_model.app_model_config | ||||
| provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1') | |||||
| return { | return { | ||||
| 'opening_statement': app_model_config.opening_statement, | 'opening_statement': app_model_config.opening_statement, | ||||
| 'suggested_questions': app_model_config.suggested_questions_list, | 'suggested_questions': app_model_config.suggested_questions_list, | ||||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | ||||
| 'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, | |||||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||||
| 'more_like_this': app_model_config.more_like_this_dict, | 'more_like_this': app_model_config.more_like_this_dict, | ||||
| 'user_input_form': app_model_config.user_input_form_list | 'user_input_form': app_model_config.user_input_form_list | ||||
| } | } |
| ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \ | ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \ | ||||
| ProviderNotSupportSpeechToTextError | ProviderNotSupportSpeechToTextError | ||||
| from controllers.service_api.wraps import AppApiResource | from controllers.service_api.wraps import AppApiResource | ||||
| from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | ||||
| from models.model import App, AppModelConfig | from models.model import App, AppModelConfig | ||||
| from services.audio_service import AudioService | from services.audio_service import AudioService |
| ProviderModelCurrentlyNotSupportError | ProviderModelCurrentlyNotSupportError | ||||
| from controllers.service_api.wraps import AppApiResource | from controllers.service_api.wraps import AppApiResource | ||||
| from core.conversation_message_task import PubHandler | from core.conversation_message_task import PubHandler | ||||
| from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | ||||
| from libs.helper import uuid_value | from libs.helper import uuid_value | ||||
| from services.completion_service import CompletionService | from services.completion_service import CompletionService |
| from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ | from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ | ||||
| DatasetNotInitedError | DatasetNotInitedError | ||||
| from controllers.service_api.wraps import DatasetApiResource | from controllers.service_api.wraps import DatasetApiResource | ||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.model_providers.error import ProviderTokenNotInitError | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from models.model import UploadFile | from models.model import UploadFile |
| from controllers.web import api | from controllers.web import api | ||||
| from controllers.web.wraps import WebApiResource | from controllers.web.wraps import WebApiResource | ||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from models.provider import ProviderName | |||||
| from models.model import App | from models.model import App | ||||
| def get(self, app_model: App, end_user): | def get(self, app_model: App, end_user): | ||||
| """Retrieve app parameters.""" | """Retrieve app parameters.""" | ||||
| app_model_config = app_model.app_model_config | app_model_config = app_model.app_model_config | ||||
| provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1') | |||||
| return { | return { | ||||
| 'opening_statement': app_model_config.opening_statement, | 'opening_statement': app_model_config.opening_statement, | ||||
| 'suggested_questions': app_model_config.suggested_questions_list, | 'suggested_questions': app_model_config.suggested_questions_list, | ||||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | ||||
| 'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, | |||||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||||
| 'more_like_this': app_model_config.more_like_this_dict, | 'more_like_this': app_model_config.more_like_this_dict, | ||||
| 'user_input_form': app_model_config.user_input_form_list | 'user_input_form': app_model_config.user_input_form_list | ||||
| } | } |
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ | ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ | ||||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | ||||
| from controllers.web.wraps import WebApiResource | from controllers.web.wraps import WebApiResource | ||||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | ||||
| from services.audio_service import AudioService | from services.audio_service import AudioService | ||||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ |
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | ||||
| from controllers.web.wraps import WebApiResource | from controllers.web.wraps import WebApiResource | ||||
| from core.conversation_message_task import PubHandler | from core.conversation_message_task import PubHandler | ||||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | ||||
| from libs.helper import uuid_value | from libs.helper import uuid_value | ||||
| from services.completion_service import CompletionService | from services.completion_service import CompletionService |
| AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ | AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ | ||||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | ||||
| from controllers.web.wraps import WebApiResource | from controllers.web.wraps import WebApiResource | ||||
| from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | ||||
| from libs.helper import uuid_value, TimestampField | from libs.helper import uuid_value, TimestampField | ||||
| from services.completion_service import CompletionService | from services.completion_service import CompletionService |
| import os | |||||
| from typing import Optional | |||||
| import langchain | |||||
| from flask import Flask | |||||
| from pydantic import BaseModel | |||||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||||
| from core.prompt.prompt_template import OneLineFormatter | |||||
| class HostedOpenAICredential(BaseModel): | |||||
| api_key: str | |||||
| class HostedAnthropicCredential(BaseModel): | |||||
| api_key: str | |||||
| class HostedLLMCredentials(BaseModel): | |||||
| openai: Optional[HostedOpenAICredential] = None | |||||
| anthropic: Optional[HostedAnthropicCredential] = None | |||||
| hosted_llm_credentials = HostedLLMCredentials() | |||||
| def init_app(app: Flask): | |||||
| if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': | |||||
| langchain.verbose = True | |||||
| if app.config.get("OPENAI_API_KEY"): | |||||
| hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) | |||||
| if app.config.get("ANTHROPIC_API_KEY"): | |||||
| hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY")) |
| from typing import cast, List | |||||
| from typing import List | |||||
| from langchain import OpenAI | |||||
| from langchain.base_language import BaseLanguageModel | |||||
| from langchain.chat_models.openai import ChatOpenAI | |||||
| from langchain.schema import BaseMessage | from langchain.schema import BaseMessage | ||||
| from core.constant import llm_constant | |||||
| from core.model_providers.models.entity.message import to_prompt_messages | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| class CalcTokenMixin: | class CalcTokenMixin: | ||||
| def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int: | |||||
| llm = cast(ChatOpenAI, llm) | |||||
| return llm.get_num_tokens_from_messages(messages) | |||||
| def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||||
| return model_instance.get_num_tokens(to_prompt_messages(messages)) | |||||
| def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int: | |||||
| def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||||
| """ | """ | ||||
| Get the remaining tokens available for the model after excluding the message tokens and the completion max tokens | Get the remaining tokens available for the model after excluding the message tokens and the completion max tokens | ||||
| :param messages: | :param messages: | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| llm = cast(ChatOpenAI, llm) | |||||
| llm_max_tokens = llm_constant.max_context_token_length[llm.model_name] | |||||
| completion_max_tokens = llm.max_tokens | |||||
| used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs) | |||||
| llm_max_tokens = model_instance.model_rules.max_tokens.max | |||||
| completion_max_tokens = model_instance.model_kwargs.max_tokens | |||||
| used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs) | |||||
| rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens | rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens | ||||
| return rest_tokens | return rest_tokens |
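A worked instance of the token budget above, with assumed numbers: a 4,096-token context window, 512 tokens reserved for the completion, and 1,200 tokens already consumed by the prompt messages.

llm_max_tokens = 4096        # model_instance.model_rules.max_tokens.max (assumed)
completion_max_tokens = 512  # model_instance.model_kwargs.max_tokens (assumed)
used_tokens = 1200           # result of get_num_tokens_from_messages (assumed)
rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
print(rest_tokens)  # 2384 tokens left for additional context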
| from langchain.callbacks.base import BaseCallbackManager | from langchain.callbacks.base import BaseCallbackManager | ||||
| from langchain.callbacks.manager import Callbacks | from langchain.callbacks.manager import Callbacks | ||||
| from langchain.prompts.chat import BaseMessagePromptTemplate | from langchain.prompts.chat import BaseMessagePromptTemplate | ||||
| from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage | |||||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage | |||||
| from langchain.schema.language_model import BaseLanguageModel | |||||
| from langchain.tools import BaseTool | from langchain.tools import BaseTool | ||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | from core.tool.dataset_retriever_tool import DatasetRetrieverTool | ||||
| """ | """ | ||||
| A multi-dataset retrieval agent driven by a router. | A multi-dataset retrieval agent driven by a router. | ||||
| """ | """ | ||||
| model_instance: BaseLLM | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| arbitrary_types_allowed = True | |||||
| def should_use_agent(self, query: str): | def should_use_agent(self, query: str): | ||||
| """ | """ |
| from langchain.callbacks.base import BaseCallbackManager | from langchain.callbacks.base import BaseCallbackManager | ||||
| from langchain.callbacks.manager import Callbacks | from langchain.callbacks.manager import Callbacks | ||||
| from langchain.prompts.chat import BaseMessagePromptTemplate | from langchain.prompts.chat import BaseMessagePromptTemplate | ||||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel | |||||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage | |||||
| from langchain.schema.language_model import BaseLanguageModel | |||||
| from langchain.tools import BaseTool | from langchain.tools import BaseTool | ||||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError | from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError | ||||
| # summarize messages if rest_tokens < 0 | # summarize messages if rest_tokens < 0 | ||||
| try: | try: | ||||
| messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions) | |||||
| messages = self.summarize_messages_if_needed(messages, functions=self.functions) | |||||
| except ExceededLLMTokensLimitError as e: | except ExceededLLMTokensLimitError as e: | ||||
| return AgentFinish(return_values={"output": str(e)}, log=str(e)) | return AgentFinish(return_values={"output": str(e)}, log=str(e)) | ||||
| from langchain.chat_models import ChatOpenAI | from langchain.chat_models import ChatOpenAI | ||||
| from langchain.chat_models.openai import _convert_message_to_dict | from langchain.chat_models.openai import _convert_message_to_dict | ||||
| from langchain.memory.summary import SummarizerMixin | from langchain.memory.summary import SummarizerMixin | ||||
| from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel | |||||
| from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage | |||||
| from langchain.schema.language_model import BaseLanguageModel | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin | from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin | ||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin): | class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin): | ||||
| moving_summary_buffer: str = "" | moving_summary_buffer: str = "" | ||||
| moving_summary_index: int = 0 | moving_summary_index: int = 0 | ||||
| summary_llm: BaseLanguageModel | summary_llm: BaseLanguageModel | ||||
| model_instance: BaseLLM | |||||
| def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| arbitrary_types_allowed = True | |||||
| def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: | |||||
| # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 | # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 | ||||
| rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs) | |||||
| rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) | |||||
| rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens | rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens | ||||
| if rest_tokens >= 0: | if rest_tokens >= 0: | ||||
| return messages | return messages |
| from langchain.callbacks.base import BaseCallbackManager | from langchain.callbacks.base import BaseCallbackManager | ||||
| from langchain.callbacks.manager import Callbacks | from langchain.callbacks.manager import Callbacks | ||||
| from langchain.prompts.chat import BaseMessagePromptTemplate | from langchain.prompts.chat import BaseMessagePromptTemplate | ||||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel | |||||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage | |||||
| from langchain.schema.language_model import BaseLanguageModel | |||||
| from langchain.tools import BaseTool | from langchain.tools import BaseTool | ||||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError | from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError | ||||
| # summarize messages if rest_tokens < 0 | # summarize messages if rest_tokens < 0 | ||||
| try: | try: | ||||
| messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions) | |||||
| messages = self.summarize_messages_if_needed(messages, functions=self.functions) | |||||
| except ExceededLLMTokensLimitError as e: | except ExceededLLMTokensLimitError as e: | ||||
| return AgentFinish(return_values={"output": str(e)}, log=str(e)) | return AgentFinish(return_values={"output": str(e)}, log=str(e)) | ||||
| import re | |||||
| from typing import List, Tuple, Any, Union, Sequence, Optional, cast | |||||
| from langchain import BasePromptTemplate | |||||
| from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent | |||||
| from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE | |||||
| from langchain.base_language import BaseLanguageModel | |||||
| from langchain.callbacks.base import BaseCallbackManager | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate | |||||
| from langchain.schema import AgentAction, AgentFinish, OutputParserException | |||||
| from langchain.tools import BaseTool | |||||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||||
| Valid "action" values: "Final Answer" or {tool_names} | |||||
| Provide only ONE action per $JSON_BLOB, as shown: | |||||
| ``` | |||||
| {{{{ | |||||
| "action": $TOOL_NAME, | |||||
| "action_input": $INPUT | |||||
| }}}} | |||||
| ``` | |||||
| Follow this format: | |||||
| Question: input question to answer | |||||
| Thought: consider previous and subsequent steps | |||||
| Action: | |||||
| ``` | |||||
| $JSON_BLOB | |||||
| ``` | |||||
| Observation: action result | |||||
| ... (repeat Thought/Action/Observation N times) | |||||
| Thought: I know what to respond | |||||
| Action: | |||||
| ``` | |||||
| {{{{ | |||||
| "action": "Final Answer", | |||||
| "action_input": "Final response to human" | |||||
| }}}} | |||||
| ```""" | |||||
| class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||||
| model_instance: BaseLLM | |||||
| dataset_tools: Sequence[BaseTool] | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| arbitrary_types_allowed = True | |||||
| def should_use_agent(self, query: str): | |||||
| """ | |||||
| Return whether the agent should be used. | |||||
| Using a ReACT pass just to decide whether an agent is needed costs an extra model call, | |||||
| so it is cheaper to let the agent do the reasoning directly. | |||||
| :param query: | |||||
| :return: | |||||
| """ | |||||
| return True | |||||
| def plan( | |||||
| self, | |||||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs: Any, | |||||
| ) -> Union[AgentAction, AgentFinish]: | |||||
| """Given input, decided what to do. | |||||
| Args: | |||||
| intermediate_steps: Steps the LLM has taken to date, | |||||
| along with observations | |||||
| callbacks: Callbacks to run. | |||||
| **kwargs: User inputs. | |||||
| Returns: | |||||
| Action specifying what tool to use. | |||||
| """ | |||||
| if len(self.dataset_tools) == 0: | |||||
| return AgentFinish(return_values={"output": ''}, log='') | |||||
| elif len(self.dataset_tools) == 1: | |||||
| tool = next(iter(self.dataset_tools)) | |||||
| tool = cast(DatasetRetrieverTool, tool) | |||||
| rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']}) | |||||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||||
| full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) | |||||
| full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) | |||||
| try: | |||||
| return self.output_parser.parse(full_output) | |||||
| except OutputParserException: | |||||
| return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " | |||||
| "I don't know how to respond to that."}, "") | |||||
| @classmethod | |||||
| def create_prompt( | |||||
| cls, | |||||
| tools: Sequence[BaseTool], | |||||
| prefix: str = PREFIX, | |||||
| suffix: str = SUFFIX, | |||||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||||
| input_variables: Optional[List[str]] = None, | |||||
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |||||
| ) -> BasePromptTemplate: | |||||
| tool_strings = [] | |||||
| for tool in tools: | |||||
| args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) | |||||
| tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") | |||||
| formatted_tools = "\n".join(tool_strings) | |||||
| unique_tool_names = set(tool.name for tool in tools) | |||||
| tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | |||||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||||
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | |||||
| if input_variables is None: | |||||
| input_variables = ["input", "agent_scratchpad"] | |||||
| _memory_prompts = memory_prompts or [] | |||||
| messages = [ | |||||
| SystemMessagePromptTemplate.from_template(template), | |||||
| *_memory_prompts, | |||||
| HumanMessagePromptTemplate.from_template(human_message_template), | |||||
| ] | |||||
| return ChatPromptTemplate(input_variables=input_variables, messages=messages) | |||||
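One subtlety worth noting: the quadrupled braces in FORMAT_INSTRUCTIONS survive two rounds of str.format-style substitution, the format_instructions.format(tool_names=...) call in create_prompt and the prompt template's own formatting at run time. Each pass halves the braces, so the model ultimately sees a single pair. A standalone demonstration:

blob = '{{{{\n  "action": $TOOL_NAME\n}}}}'
after_first = blob.format()          # stands in for .format(tool_names=...) above
after_second = after_first.format()  # stands in for run-time prompt formatting
print(after_second)  # {
                     #   "action": $TOOL_NAME
                     # }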
| @classmethod | |||||
| def from_llm_and_tools( | |||||
| cls, | |||||
| llm: BaseLanguageModel, | |||||
| tools: Sequence[BaseTool], | |||||
| callback_manager: Optional[BaseCallbackManager] = None, | |||||
| output_parser: Optional[AgentOutputParser] = None, | |||||
| prefix: str = PREFIX, | |||||
| suffix: str = SUFFIX, | |||||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||||
| input_variables: Optional[List[str]] = None, | |||||
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> Agent: | |||||
| return super().from_llm_and_tools( | |||||
| llm=llm, | |||||
| tools=tools, | |||||
| callback_manager=callback_manager, | |||||
| output_parser=output_parser, | |||||
| prefix=prefix, | |||||
| suffix=suffix, | |||||
| human_message_template=human_message_template, | |||||
| format_instructions=format_instructions, | |||||
| input_variables=input_variables, | |||||
| memory_prompts=memory_prompts, | |||||
| dataset_tools=tools, | |||||
| **kwargs, | |||||
| ) |
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | ||||
| from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError | from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError | ||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | ||||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | ||||
| moving_summary_buffer: str = "" | moving_summary_buffer: str = "" | ||||
| moving_summary_index: int = 0 | moving_summary_index: int = 0 | ||||
| summary_llm: BaseLanguageModel | summary_llm: BaseLanguageModel | ||||
| model_instance: BaseLLM | |||||
| class Config: | |||||
| """Configuration for this pydantic object.""" | |||||
| arbitrary_types_allowed = True | |||||
| def should_use_agent(self, query: str): | def should_use_agent(self, query: str): | ||||
| """ | """ | ||||
| if prompts: | if prompts: | ||||
| messages = prompts[0].to_messages() | messages = prompts[0].to_messages() | ||||
| rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages) | |||||
| rest_tokens = self.get_message_rest_tokens(self.model_instance, messages) | |||||
| if rest_tokens < 0: | if rest_tokens < 0: | ||||
| full_inputs = self.summarize_messages(intermediate_steps, **kwargs) | full_inputs = self.summarize_messages(intermediate_steps, **kwargs) | ||||
| from typing import Union, Optional | from typing import Union, Optional | ||||
| from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent | from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent | ||||
| from langchain.base_language import BaseLanguageModel | |||||
| from langchain.callbacks.manager import Callbacks | from langchain.callbacks.manager import Callbacks | ||||
| from langchain.memory.chat_memory import BaseChatMemory | from langchain.memory.chat_memory import BaseChatMemory | ||||
| from langchain.tools import BaseTool | from langchain.tools import BaseTool | ||||
| from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent | from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent | ||||
| from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent | from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent | ||||
| from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser | from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser | ||||
| from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent | |||||
| from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent | from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent | ||||
| from langchain.agents import AgentExecutor as LCAgentExecutor | from langchain.agents import AgentExecutor as LCAgentExecutor | ||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | from core.tool.dataset_retriever_tool import DatasetRetrieverTool | ||||
| class PlanningStrategy(str, enum.Enum): | class PlanningStrategy(str, enum.Enum): | ||||
| ROUTER = 'router' | ROUTER = 'router' | ||||
| REACT_ROUTER = 'react_router' | |||||
| REACT = 'react' | REACT = 'react' | ||||
| FUNCTION_CALL = 'function_call' | FUNCTION_CALL = 'function_call' | ||||
| MULTI_FUNCTION_CALL = 'multi_function_call' | MULTI_FUNCTION_CALL = 'multi_function_call' | ||||
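Because PlanningStrategy mixes in str, raw config strings deserialize straight into members and compare equal to their values. A quick standalone illustration (the class is re-declared here so the snippet runs on its own):

import enum

class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'

assert PlanningStrategy('react_router') is PlanningStrategy.REACT_ROUTER
assert PlanningStrategy.ROUTER == 'router'  # str mixin: equal to the raw value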
| class AgentConfiguration(BaseModel): | class AgentConfiguration(BaseModel): | ||||
| strategy: PlanningStrategy | strategy: PlanningStrategy | ||||
| llm: BaseLanguageModel | |||||
| model_instance: BaseLLM | |||||
| tools: list[BaseTool] | tools: list[BaseTool] | ||||
| summary_llm: BaseLanguageModel | |||||
| dataset_llm: BaseLanguageModel | |||||
| summary_model_instance: BaseLLM | |||||
| memory: Optional[BaseChatMemory] = None | memory: Optional[BaseChatMemory] = None | ||||
| callbacks: Callbacks = None | callbacks: Callbacks = None | ||||
| max_iterations: int = 6 | max_iterations: int = 6 | ||||
| def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: | def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: | ||||
| if self.configuration.strategy == PlanningStrategy.REACT: | if self.configuration.strategy == PlanningStrategy.REACT: | ||||
| agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( | agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( | ||||
| llm=self.configuration.llm, | |||||
| model_instance=self.configuration.model_instance, | |||||
| llm=self.configuration.model_instance.client, | |||||
| tools=self.configuration.tools, | tools=self.configuration.tools, | ||||
| output_parser=StructuredChatOutputParser(), | output_parser=StructuredChatOutputParser(), | ||||
| summary_llm=self.configuration.summary_llm, | |||||
| summary_llm=self.configuration.summary_model_instance.client, | |||||
| verbose=True | verbose=True | ||||
| ) | ) | ||||
| elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: | elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: | ||||
| agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( | agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( | ||||
| llm=self.configuration.llm, | |||||
| model_instance=self.configuration.model_instance, | |||||
| llm=self.configuration.model_instance.client, | |||||
| tools=self.configuration.tools, | tools=self.configuration.tools, | ||||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory | extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory | ||||
| summary_llm=self.configuration.summary_llm, | |||||
| summary_llm=self.configuration.summary_model_instance.client, | |||||
| verbose=True | verbose=True | ||||
| ) | ) | ||||
| elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: | elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: | ||||
| agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools( | agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools( | ||||
| llm=self.configuration.llm, | |||||
| model_instance=self.configuration.model_instance, | |||||
| llm=self.configuration.model_instance.client, | |||||
| tools=self.configuration.tools, | tools=self.configuration.tools, | ||||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory | extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory | ||||
| summary_llm=self.configuration.summary_llm, | |||||
| summary_llm=self.configuration.summary_model_instance.client, | |||||
| verbose=True | verbose=True | ||||
| ) | ) | ||||
| elif self.configuration.strategy == PlanningStrategy.ROUTER: | elif self.configuration.strategy == PlanningStrategy.ROUTER: | ||||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | ||||
| agent = MultiDatasetRouterAgent.from_llm_and_tools( | agent = MultiDatasetRouterAgent.from_llm_and_tools( | ||||
| llm=self.configuration.dataset_llm, | |||||
| model_instance=self.configuration.model_instance, | |||||
| llm=self.configuration.model_instance.client, | |||||
| tools=self.configuration.tools, | tools=self.configuration.tools, | ||||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, | extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, | ||||
| verbose=True | verbose=True | ||||
| ) | ) | ||||
| elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: | |||||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | |||||
| agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( | |||||
| model_instance=self.configuration.model_instance, | |||||
| llm=self.configuration.model_instance.client, | |||||
| tools=self.configuration.tools, | |||||
| output_parser=StructuredChatOutputParser(), | |||||
| verbose=True | |||||
| ) | |||||
| else: | else: | ||||
| raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}") | raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}") | ||||
| from core.callback_handler.entity.agent_loop import AgentLoop | from core.callback_handler.entity.agent_loop import AgentLoop | ||||
| from core.conversation_message_task import ConversationMessageTask | from core.conversation_message_task import ConversationMessageTask | ||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | ||||
| """Callback Handler that prints to std out.""" | """Callback Handler that prints to std out.""" | ||||
| raise_error: bool = True | raise_error: bool = True | ||||
| def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None: | |||||
| def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None: | |||||
| """Initialize callback handler.""" | """Initialize callback handler.""" | ||||
| self.model_name = model_name | |||||
| self.model_instance = model_instance | |||||
| self.conversation_message_task = conversation_message_task | self.conversation_message_task = conversation_message_task | ||||
| self._agent_loops = [] | self._agent_loops = [] | ||||
| self._current_loop = None | self._current_loop = None | ||||
| self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at | self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at | ||||
| self.conversation_message_task.on_agent_end( | self.conversation_message_task.on_agent_end( | ||||
| self._message_agent_thought, self.model_name, self._current_loop | |||||
| self._message_agent_thought, self.model_instance, self._current_loop | |||||
| ) | ) | ||||
| self._agent_loops.append(self._current_loop) | self._agent_loops.append(self._current_loop) | ||||
| ) | ) | ||||
| self.conversation_message_task.on_agent_end( | self.conversation_message_task.on_agent_end( | ||||
| self._message_agent_thought, self.model_name, self._current_loop | |||||
| self._message_agent_thought, self.model_instance, self._current_loop | |||||
| ) | ) | ||||
| self._agent_loops.append(self._current_loop) | self._agent_loops.append(self._current_loop) |
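The handler's constructor change follows the same idea: it now stores the whole model instance rather than a bare model name, so `on_agent_end` can hand pricing and currency questions to the instance itself. A stubbed sketch of the forwarding (all names here are stand-ins):

```python
# Sketch only: minimal stand-ins showing the instance flowing through the handler.
class ModelInstanceStub:
    def get_currency(self) -> str:
        return "USD"

class TaskStub:
    def on_agent_end(self, thought, model_instance, loop) -> None:
        # The task can now derive pricing/currency from the instance directly.
        print("loop priced in", model_instance.get_currency())

class LoopGatherHandlerStub:
    def __init__(self, model_instance, conversation_message_task):
        self.model_instance = model_instance
        self.conversation_message_task = conversation_message_task

    def finish_loop(self, thought, loop) -> None:
        self.conversation_message_task.on_agent_end(thought, self.model_instance, loop)

LoopGatherHandlerStub(ModelInstanceStub(), TaskStub()).finish_loop("thought", "loop")
```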
| from typing import Any, Dict, List, Union | from typing import Any, Dict, List, Union | ||||
| from langchain.callbacks.base import BaseCallbackHandler | from langchain.callbacks.base import BaseCallbackHandler | ||||
| from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel | |||||
| from langchain.schema import LLMResult, BaseMessage | |||||
| from core.callback_handler.entity.llm_message import LLMMessage | from core.callback_handler.entity.llm_message import LLMMessage | ||||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | ||||
| from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| class LLMCallbackHandler(BaseCallbackHandler): | class LLMCallbackHandler(BaseCallbackHandler): | ||||
| raise_error: bool = True | raise_error: bool = True | ||||
| def __init__(self, llm: BaseLanguageModel, | |||||
| def __init__(self, model_instance: BaseLLM, | |||||
| conversation_message_task: ConversationMessageTask): | conversation_message_task: ConversationMessageTask): | ||||
| self.llm = llm | |||||
| self.model_instance = model_instance | |||||
| self.llm_message = LLMMessage() | self.llm_message = LLMMessage() | ||||
| self.start_at = None | self.start_at = None | ||||
| self.conversation_message_task = conversation_message_task | self.conversation_message_task = conversation_message_task | ||||
| }) | }) | ||||
| self.llm_message.prompt = real_prompts | self.llm_message.prompt = real_prompts | ||||
| self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0]) | |||||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0])) | |||||
| def on_llm_start( | def on_llm_start( | ||||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | ||||
| "text": prompts[0] | "text": prompts[0] | ||||
| }] | }] | ||||
| self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) | |||||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])]) | |||||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| self.conversation_message_task.append_message_text(response.generations[0][0].text) | self.conversation_message_task.append_message_text(response.generations[0][0].text) | ||||
| self.llm_message.completion = response.generations[0][0].text | self.llm_message.completion = response.generations[0][0].text | ||||
| self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion) | |||||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)]) | |||||
| self.conversation_message_task.save_message(self.llm_message) | self.conversation_message_task.save_message(self.llm_message) | ||||
| if self.conversation_message_task.streaming: | if self.conversation_message_task.streaming: | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| self.llm_message.latency = end_at - self.start_at | self.llm_message.latency = end_at - self.start_at | ||||
| self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion) | |||||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||||
| [PromptMessage(content=self.llm_message.completion)] | |||||
| ) | |||||
| self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) | self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) | ||||
| else: | else: | ||||
| logging.error(error) | logging.error(error) |
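Token counting in this handler now goes through `model_instance.get_num_tokens(...)` with `PromptMessage` wrappers instead of calling the LangChain model's counters directly. A self-contained sketch, with a naive whitespace tokenizer standing in for the provider tokenizer:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class PromptMessage:  # stand-in for the repo's PromptMessage entity
    content: str

class ModelInstanceStub:
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        # Naive whitespace count; the real instance delegates to the
        # provider's own tokenizer.
        return sum(len(m.content.split()) for m in messages)

instance = ModelInstanceStub()
prompt_tokens = instance.get_num_tokens([PromptMessage(content="Hello there, world")])
print(prompt_tokens)  # 3
```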
| from langchain.callbacks.base import BaseCallbackHandler | from langchain.callbacks.base import BaseCallbackHandler | ||||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||||
| from core.callback_handler.entity.chain_result import ChainResult | from core.callback_handler.entity.chain_result import ChainResult | ||||
| from core.constant import llm_constant | |||||
| from core.conversation_message_task import ConversationMessageTask | from core.conversation_message_task import ConversationMessageTask | ||||
| import re | import re | ||||
| from typing import Optional, List, Union, Tuple | from typing import Optional, List, Union, Tuple | ||||
| from langchain.base_language import BaseLanguageModel | |||||
| from langchain.callbacks.base import BaseCallbackHandler | |||||
| from langchain.chat_models.base import BaseChatModel | |||||
| from langchain.llms import BaseLLM | |||||
| from langchain.schema import BaseMessage, HumanMessage | |||||
| from langchain.schema import BaseMessage | |||||
| from requests.exceptions import ChunkedEncodingError | from requests.exceptions import ChunkedEncodingError | ||||
| from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy | from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy | ||||
| from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | ||||
| from core.constant import llm_constant | |||||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | from core.callback_handler.llm_callback_handler import LLMCallbackHandler | ||||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ | |||||
| DifyStdOutCallbackHandler | |||||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | ||||
| from core.llm.error import LLMBadRequestError | |||||
| from core.llm.fake import FakeLLM | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | |||||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||||
| from core.model_providers.error import LLMBadRequestError | |||||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | ||||
| ReadOnlyConversationTokenDBBufferSharedMemory | ReadOnlyConversationTokenDBBufferSharedMemory | ||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.orchestrator_rule_parser import OrchestratorRuleParser | from core.orchestrator_rule_parser import OrchestratorRuleParser | ||||
| from core.prompt.prompt_builder import PromptBuilder | from core.prompt.prompt_builder import PromptBuilder | ||||
| from core.prompt.prompt_template import JinjaPromptTemplate | from core.prompt.prompt_template import JinjaPromptTemplate | ||||
| inputs = conversation.inputs | inputs = conversation.inputs | ||||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||||
| mode=app.mode, | |||||
| final_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||||
| tenant_id=app.tenant_id, | tenant_id=app.tenant_id, | ||||
| app_model_config=app_model_config, | |||||
| query=query, | |||||
| inputs=inputs | |||||
| model_config=app_model_config.model_dict, | |||||
| streaming=streaming | |||||
| ) | ) | ||||
| conversation_message_task = ConversationMessageTask( | conversation_message_task = ConversationMessageTask( | ||||
| is_override=is_override, | is_override=is_override, | ||||
| inputs=inputs, | inputs=inputs, | ||||
| query=query, | query=query, | ||||
| streaming=streaming | |||||
| streaming=streaming, | |||||
| model_instance=final_model_instance | |||||
| ) | ) | ||||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||||
| mode=app.mode, | |||||
| model_instance=final_model_instance, | |||||
| app_model_config=app_model_config, | |||||
| query=query, | |||||
| inputs=inputs | |||||
| ) | |||||
| # init orchestrator rule parser | # init orchestrator rule parser | ||||
| orchestrator_rule_parser = OrchestratorRuleParser( | orchestrator_rule_parser = OrchestratorRuleParser( | ||||
| ) | ) | ||||
| # parse sensitive_word_avoidance_chain | # parse sensitive_word_avoidance_chain | ||||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||||
| sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback]) | sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback]) | ||||
| if sensitive_word_avoidance_chain: | if sensitive_word_avoidance_chain: | ||||
| query = sensitive_word_avoidance_chain.run(query) | query = sensitive_word_avoidance_chain.run(query) | ||||
| # run the final llm | # run the final llm | ||||
| try: | try: | ||||
| cls.run_final_llm( | cls.run_final_llm( | ||||
| tenant_id=app.tenant_id, | |||||
| model_instance=final_model_instance, | |||||
| mode=app.mode, | mode=app.mode, | ||||
| app_model_config=app_model_config, | app_model_config=app_model_config, | ||||
| query=query, | query=query, | ||||
| inputs=inputs, | inputs=inputs, | ||||
| agent_execute_result=agent_execute_result, | agent_execute_result=agent_execute_result, | ||||
| conversation_message_task=conversation_message_task, | conversation_message_task=conversation_message_task, | ||||
| memory=memory, | |||||
| streaming=streaming | |||||
| memory=memory | |||||
| ) | ) | ||||
| except ConversationTaskStoppedException: | except ConversationTaskStoppedException: | ||||
| return | return | ||||
| return | return | ||||
| @classmethod | @classmethod | ||||
| def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, | |||||
| def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, | |||||
| agent_execute_result: Optional[AgentExecuteResult], | agent_execute_result: Optional[AgentExecuteResult], | ||||
| conversation_message_task: ConversationMessageTask, | conversation_message_task: ConversationMessageTask, | ||||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool): | |||||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): | |||||
| # When no extra pre prompt is specified, | # When no extra pre prompt is specified, | ||||
| # the output of the agent can be used directly as the main output content without calling LLM again | # the output of the agent can be used directly as the main output content without calling LLM again | ||||
| fake_response = None | |||||
| if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ | if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ | ||||
| and agent_execute_result.strategy != PlanningStrategy.ROUTER: | and agent_execute_result.strategy != PlanningStrategy.ROUTER: | ||||
| final_llm = FakeLLM(response=agent_execute_result.output, | |||||
| origin_llm=agent_execute_result.configuration.llm, | |||||
| streaming=streaming) | |||||
| final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) | |||||
| response = final_llm.generate([[HumanMessage(content=query)]]) | |||||
| return response | |||||
| final_llm = LLMBuilder.to_llm_from_model( | |||||
| tenant_id=tenant_id, | |||||
| model=app_model_config.model_dict, | |||||
| streaming=streaming | |||||
| ) | |||||
| fake_response = agent_execute_result.output | |||||
| # get llm prompt | # get llm prompt | ||||
| prompt, stop_words = cls.get_main_llm_prompt( | |||||
| prompt_messages, stop_words = cls.get_main_llm_prompt( | |||||
| mode=mode, | mode=mode, | ||||
| llm=final_llm, | |||||
| model=app_model_config.model_dict, | model=app_model_config.model_dict, | ||||
| pre_prompt=app_model_config.pre_prompt, | pre_prompt=app_model_config.pre_prompt, | ||||
| query=query, | query=query, | ||||
| memory=memory | memory=memory | ||||
| ) | ) | ||||
| final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) | |||||
| cls.recalc_llm_max_tokens( | cls.recalc_llm_max_tokens( | ||||
| final_llm=final_llm, | |||||
| model=app_model_config.model_dict, | |||||
| prompt=prompt, | |||||
| mode=mode | |||||
| model_instance=model_instance, | |||||
| prompt_messages=prompt_messages, | |||||
| ) | ) | ||||
| response = final_llm.generate([prompt], stop_words) | |||||
| response = model_instance.run( | |||||
| messages=prompt_messages, | |||||
| stop=stop_words, | |||||
| callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], | |||||
| fake_response=fake_response | |||||
| ) | |||||
| return response | return response | ||||
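The old `FakeLLM` wrapper for agent answers disappears: `run_final_llm` now passes the agent's output as `fake_response` into `model_instance.run(...)`, which presumably echoes it back without a provider call. A sketch of that short-circuit under that assumption:

```python
from typing import List, Optional

class ResultStub:
    def __init__(self, content: str):
        self.content = content

class ModelInstanceStub:
    def run(self, messages: List[str], stop: Optional[List[str]] = None,
            callbacks=None, fake_response: Optional[str] = None) -> ResultStub:
        if fake_response is not None:
            # The agent already produced the final answer: skip the provider.
            return ResultStub(fake_response)
        return ResultStub("provider output for: " + messages[0])

print(ModelInstanceStub().run(["query"], fake_response="agent answer").content)
```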
| @classmethod | @classmethod | ||||
| def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict, | |||||
| def get_main_llm_prompt(cls, mode: str, model: dict, | |||||
| pre_prompt: str, query: str, inputs: dict, | pre_prompt: str, query: str, inputs: dict, | ||||
| agent_execute_result: Optional[AgentExecuteResult], | agent_execute_result: Optional[AgentExecuteResult], | ||||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ | memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ | ||||
| Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]: | |||||
| Tuple[List[PromptMessage], Optional[List[str]]]: | |||||
| if mode == 'completion': | if mode == 'completion': | ||||
| prompt_template = JinjaPromptTemplate.from_template( | prompt_template = JinjaPromptTemplate.from_template( | ||||
| template=("""Use the following context as your learned knowledge, inside <context></context> XML tags. | template=("""Use the following context as your learned knowledge, inside <context></context> XML tags. | ||||
| **prompt_inputs | **prompt_inputs | ||||
| ) | ) | ||||
| if isinstance(llm, BaseChatModel): | |||||
| # use chat llm as completion model | |||||
| return [HumanMessage(content=prompt_content)], None | |||||
| else: | |||||
| return prompt_content, None | |||||
| return [PromptMessage(content=prompt_content)], None | |||||
| else: | else: | ||||
| messages: List[BaseMessage] = [] | messages: List[BaseMessage] = [] | ||||
| inputs=human_inputs | inputs=human_inputs | ||||
| ) | ) | ||||
| curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message]) | |||||
| model_name = model['name'] | |||||
| max_tokens = model.get("completion_params").get('max_tokens') | |||||
| rest_tokens = llm_constant.max_context_token_length[model_name] \ | |||||
| - max_tokens - curr_message_tokens | |||||
| rest_tokens = max(rest_tokens, 0) | |||||
| if memory.model_instance.model_rules.max_tokens.max: | |||||
| curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message])) | |||||
| max_tokens = model.get("completion_params").get('max_tokens') | |||||
| rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens | |||||
| rest_tokens = max(rest_tokens, 0) | |||||
| else: | |||||
| rest_tokens = 2000 | |||||
| histories = cls.get_history_messages_from_memory(memory, rest_tokens) | histories = cls.get_history_messages_from_memory(memory, rest_tokens) | ||||
| human_message_prompt += "\n\n" if human_message_prompt else "" | human_message_prompt += "\n\n" if human_message_prompt else "" | ||||
| human_message_prompt += "Here is the chat histories between human and assistant, " \ | human_message_prompt += "Here is the chat histories between human and assistant, " \ | ||||
| for message in messages: | for message in messages: | ||||
| message.content = re.sub(r'<\|.*?\|>', '', message.content) | message.content = re.sub(r'<\|.*?\|>', '', message.content) | ||||
| return messages, ['\nHuman:', '</histories>'] | |||||
| @classmethod | |||||
| def get_llm_callbacks(cls, llm: BaseLanguageModel, | |||||
| streaming: bool, | |||||
| conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]: | |||||
| llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) | |||||
| if streaming: | |||||
| return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] | |||||
| else: | |||||
| return [llm_callback_handler, DifyStdOutCallbackHandler()] | |||||
| return to_prompt_messages(messages), ['\nHuman:', '</histories>'] | |||||
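History budgeting now reads the context limit from `memory.model_instance.model_rules.max_tokens.max` and falls back to a flat 2000-token budget when the model declares no limit. A small worked sketch of the branch (numbers are illustrative):

```python
def rest_tokens_for_history(context_limit, max_completion_tokens, current_prompt_tokens):
    """Mirror of the budgeting branch: a falsy context limit means a flat 2000-token budget."""
    if context_limit:
        return max(context_limit - max_completion_tokens - current_prompt_tokens, 0)
    return 2000

print(rest_tokens_for_history(4096, 512, 300))   # 3284 tokens left for history
print(rest_tokens_for_history(None, 512, 300))   # 2000 fallback
```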
| @classmethod | @classmethod | ||||
| def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | ||||
| conversation: Conversation, | conversation: Conversation, | ||||
| **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory: | **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory: | ||||
| # only for calc token in memory | # only for calc token in memory | ||||
| memory_llm = LLMBuilder.to_llm_from_model( | |||||
| memory_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model=app_model_config.model_dict | |||||
| model_config=app_model_config.model_dict | |||||
| ) | ) | ||||
| # use llm config from conversation | # use llm config from conversation | ||||
| memory = ReadOnlyConversationTokenDBBufferSharedMemory( | memory = ReadOnlyConversationTokenDBBufferSharedMemory( | ||||
| conversation=conversation, | conversation=conversation, | ||||
| llm=memory_llm, | |||||
| model_instance=memory_model_instance, | |||||
| max_token_limit=kwargs.get("max_token_limit", 2048), | max_token_limit=kwargs.get("max_token_limit", 2048), | ||||
| memory_key=kwargs.get("memory_key", "chat_history"), | memory_key=kwargs.get("memory_key", "chat_history"), | ||||
| return_messages=kwargs.get("return_messages", True), | return_messages=kwargs.get("return_messages", True), | ||||
| return memory | return memory | ||||
| @classmethod | @classmethod | ||||
| def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig, | |||||
| def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig, | |||||
| query: str, inputs: dict) -> int: | query: str, inputs: dict) -> int: | ||||
| llm = LLMBuilder.to_llm_from_model( | |||||
| tenant_id=tenant_id, | |||||
| model=app_model_config.model_dict | |||||
| ) | |||||
| model_limited_tokens = model_instance.model_rules.max_tokens.max | |||||
| max_tokens = model_instance.get_model_kwargs().max_tokens | |||||
| model_name = app_model_config.model_dict.get("name") | |||||
| model_limited_tokens = llm_constant.max_context_token_length[model_name] | |||||
| max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens') | |||||
| if model_limited_tokens is None: | |||||
| return -1 | |||||
| if max_tokens is None: | |||||
| max_tokens = 0 | |||||
| # get prompt without memory and context | # get prompt without memory and context | ||||
| prompt, _ = cls.get_main_llm_prompt( | |||||
| prompt_messages, _ = cls.get_main_llm_prompt( | |||||
| mode=mode, | mode=mode, | ||||
| llm=llm, | |||||
| model=app_model_config.model_dict, | model=app_model_config.model_dict, | ||||
| pre_prompt=app_model_config.pre_prompt, | pre_prompt=app_model_config.pre_prompt, | ||||
| query=query, | query=query, | ||||
| memory=None | memory=None | ||||
| ) | ) | ||||
| prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \ | |||||
| else llm.get_num_tokens_from_messages(prompt) | |||||
| prompt_tokens = model_instance.get_num_tokens(prompt_messages) | |||||
| rest_tokens = model_limited_tokens - max_tokens - prompt_tokens | rest_tokens = model_limited_tokens - max_tokens - prompt_tokens | ||||
| if rest_tokens < 0: | if rest_tokens < 0: | ||||
| raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | ||||
| return rest_tokens | return rest_tokens | ||||
| @classmethod | @classmethod | ||||
| def recalc_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict, | |||||
| prompt: Union[str, List[BaseMessage]], mode: str): | |||||
| def recalc_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]): | |||||
| # recalc max_tokens if prompt_tokens + max_tokens exceeds the model token limit | # recalc max_tokens if prompt_tokens + max_tokens exceeds the model token limit | ||||
| model_name = model.get("name") | |||||
| model_limited_tokens = llm_constant.max_context_token_length[model_name] | |||||
| max_tokens = model.get("completion_params").get('max_tokens') | |||||
| model_limited_tokens = model_instance.model_rules.max_tokens.max | |||||
| max_tokens = model_instance.get_model_kwargs().max_tokens | |||||
| if mode == 'completion' and isinstance(final_llm, BaseLLM): | |||||
| prompt_tokens = final_llm.get_num_tokens(prompt) | |||||
| else: | |||||
| prompt_tokens = final_llm.get_num_tokens_from_messages(prompt) | |||||
| if model_limited_tokens is None: | |||||
| return | |||||
| if max_tokens is None: | |||||
| max_tokens = 0 | |||||
| prompt_tokens = model_instance.get_num_tokens(prompt_messages) | |||||
| if prompt_tokens + max_tokens > model_limited_tokens: | if prompt_tokens + max_tokens > model_limited_tokens: | ||||
| max_tokens = max(model_limited_tokens - prompt_tokens, 16) | max_tokens = max(model_limited_tokens - prompt_tokens, 16) | ||||
| final_llm.max_tokens = max_tokens | |||||
| # update model instance max tokens | |||||
| model_kwargs = model_instance.get_model_kwargs() | |||||
| model_kwargs.max_tokens = max_tokens | |||||
| model_instance.set_model_kwargs(model_kwargs) | |||||
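The clamp itself is unchanged arithmetic; only its inputs now come from `model_rules` and `ModelKwargs`. If prompt tokens plus the requested completion budget exceed the context limit, the budget shrinks to whatever remains, floored at 16 tokens:

```python
# Worked example of the clamp (illustrative numbers):
model_limited_tokens = 4096   # model_rules.max_tokens.max
prompt_tokens = 4000
max_tokens = 1024             # requested completion budget

if prompt_tokens + max_tokens > model_limited_tokens:
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)

print(max_tokens)  # 96: the completion budget is squeezed down to what still fits
```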
| @classmethod | @classmethod | ||||
| def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, | def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, | ||||
| app_model_config: AppModelConfig, user: Account, streaming: bool): | app_model_config: AppModelConfig, user: Account, streaming: bool): | ||||
| llm = LLMBuilder.to_llm_from_model( | |||||
| final_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||||
| tenant_id=app.tenant_id, | tenant_id=app.tenant_id, | ||||
| model=app_model_config.model_dict, | |||||
| model_config=app_model_config.model_dict, | |||||
| streaming=streaming | streaming=streaming | ||||
| ) | ) | ||||
| # get llm prompt | # get llm prompt | ||||
| original_prompt, _ = cls.get_main_llm_prompt( | |||||
| old_prompt_messages, _ = cls.get_main_llm_prompt( | |||||
| mode="completion", | mode="completion", | ||||
| llm=llm, | |||||
| model=app_model_config.model_dict, | model=app_model_config.model_dict, | ||||
| pre_prompt=pre_prompt, | pre_prompt=pre_prompt, | ||||
| query=message.query, | query=message.query, | ||||
| original_completion = message.answer.strip() | original_completion = message.answer.strip() | ||||
| prompt = MORE_LIKE_THIS_GENERATE_PROMPT | prompt = MORE_LIKE_THIS_GENERATE_PROMPT | ||||
| prompt = prompt.format(prompt=original_prompt, original_completion=original_completion) | |||||
| prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion) | |||||
| if isinstance(llm, BaseChatModel): | |||||
| prompt = [HumanMessage(content=prompt)] | |||||
| prompt_messages = [PromptMessage(content=prompt)] | |||||
| conversation_message_task = ConversationMessageTask( | conversation_message_task = ConversationMessageTask( | ||||
| task_id=task_id, | task_id=task_id, | ||||
| inputs=message.inputs, | inputs=message.inputs, | ||||
| query=message.query, | query=message.query, | ||||
| is_override=True if message.override_model_configs else False, | is_override=True if message.override_model_configs else False, | ||||
| streaming=streaming | |||||
| streaming=streaming, | |||||
| model_instance=final_model_instance | |||||
| ) | ) | ||||
| llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task) | |||||
| cls.recalc_llm_max_tokens( | cls.recalc_llm_max_tokens( | ||||
| final_llm=llm, | |||||
| model=app_model_config.model_dict, | |||||
| prompt=prompt, | |||||
| mode='completion' | |||||
| model_instance=final_model_instance, | |||||
| prompt_messages=prompt_messages | |||||
| ) | ) | ||||
| llm.generate([prompt]) | |||||
| final_model_instance.run( | |||||
| messages=prompt_messages, | |||||
| callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)] | |||||
| ) |
| from _decimal import Decimal | |||||
| models = { | |||||
| 'claude-instant-1': 'anthropic', # 100,000 tokens | |||||
| 'claude-2': 'anthropic', # 100,000 tokens | |||||
| 'gpt-4': 'openai', # 8,192 tokens | |||||
| 'gpt-4-32k': 'openai', # 32,768 tokens | |||||
| 'gpt-3.5-turbo': 'openai', # 4,096 tokens | |||||
| 'gpt-3.5-turbo-16k': 'openai', # 16384 tokens | |||||
| 'text-davinci-003': 'openai', # 4,097 tokens | |||||
| 'text-davinci-002': 'openai', # 4,097 tokens | |||||
| 'text-curie-001': 'openai', # 2,049 tokens | |||||
| 'text-babbage-001': 'openai', # 2,049 tokens | |||||
| 'text-ada-001': 'openai', # 2,049 tokens | |||||
| 'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions | |||||
| 'whisper-1': 'openai' | |||||
| } | |||||
| max_context_token_length = { | |||||
| 'claude-instant-1': 100000, | |||||
| 'claude-2': 100000, | |||||
| 'gpt-4': 8192, | |||||
| 'gpt-4-32k': 32768, | |||||
| 'gpt-3.5-turbo': 4096, | |||||
| 'gpt-3.5-turbo-16k': 16384, | |||||
| 'text-davinci-003': 4097, | |||||
| 'text-davinci-002': 4097, | |||||
| 'text-curie-001': 2049, | |||||
| 'text-babbage-001': 2049, | |||||
| 'text-ada-001': 2049, | |||||
| 'text-embedding-ada-002': 8191, | |||||
| } | |||||
| models_by_mode = { | |||||
| 'chat': [ | |||||
| 'claude-instant-1', # 100,000 tokens | |||||
| 'claude-2', # 100,000 tokens | |||||
| 'gpt-4', # 8,192 tokens | |||||
| 'gpt-4-32k', # 32,768 tokens | |||||
| 'gpt-3.5-turbo', # 4,096 tokens | |||||
| 'gpt-3.5-turbo-16k', # 16,384 tokens | |||||
| ], | |||||
| 'completion': [ | |||||
| 'claude-instant-1', # 100,000 tokens | |||||
| 'claude-2', # 100,000 tokens | |||||
| 'gpt-4', # 8,192 tokens | |||||
| 'gpt-4-32k', # 32,768 tokens | |||||
| 'gpt-3.5-turbo', # 4,096 tokens | |||||
| 'gpt-3.5-turbo-16k', # 16,384 tokens | |||||
| 'text-davinci-003', # 4,097 tokens | |||||
| 'text-davinci-002', # 4,097 tokens | |||||
| 'text-curie-001', # 2,049 tokens | |||||
| 'text-babbage-001', # 2,049 tokens | |||||
| 'text-ada-001' # 2,049 tokens | |||||
| ], | |||||
| 'embedding': [ | |||||
| 'text-embedding-ada-002' # 8191 tokens, 1536 dimensions | |||||
| ] | |||||
| } | |||||
| model_currency = 'USD' | |||||
| model_prices = { | |||||
| 'claude-instant-1': { | |||||
| 'prompt': Decimal('0.00163'), | |||||
| 'completion': Decimal('0.00551'), | |||||
| }, | |||||
| 'claude-2': { | |||||
| 'prompt': Decimal('0.01102'), | |||||
| 'completion': Decimal('0.03268'), | |||||
| }, | |||||
| 'gpt-4': { | |||||
| 'prompt': Decimal('0.03'), | |||||
| 'completion': Decimal('0.06'), | |||||
| }, | |||||
| 'gpt-4-32k': { | |||||
| 'prompt': Decimal('0.06'), | |||||
| 'completion': Decimal('0.12') | |||||
| }, | |||||
| 'gpt-3.5-turbo': { | |||||
| 'prompt': Decimal('0.0015'), | |||||
| 'completion': Decimal('0.002') | |||||
| }, | |||||
| 'gpt-3.5-turbo-16k': { | |||||
| 'prompt': Decimal('0.003'), | |||||
| 'completion': Decimal('0.004') | |||||
| }, | |||||
| 'text-davinci-003': { | |||||
| 'prompt': Decimal('0.02'), | |||||
| 'completion': Decimal('0.02') | |||||
| }, | |||||
| 'text-curie-001': { | |||||
| 'prompt': Decimal('0.002'), | |||||
| 'completion': Decimal('0.002') | |||||
| }, | |||||
| 'text-babbage-001': { | |||||
| 'prompt': Decimal('0.0005'), | |||||
| 'completion': Decimal('0.0005') | |||||
| }, | |||||
| 'text-ada-001': { | |||||
| 'prompt': Decimal('0.0004'), | |||||
| 'completion': Decimal('0.0004') | |||||
| }, | |||||
| 'text-embedding-ada-002': { | |||||
| 'usage': Decimal('0.0001'), | |||||
| } | |||||
| } | |||||
| agent_model_name = 'text-davinci-003' |
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | from core.callback_handler.entity.dataset_query import DatasetQueryObj | ||||
| from core.callback_handler.entity.llm_message import LLMMessage | from core.callback_handler.entity.llm_message import LLMMessage | ||||
| from core.callback_handler.entity.chain_result import ChainResult | from core.callback_handler.entity.chain_result import ChainResult | ||||
| from core.constant import llm_constant | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from core.llm.provider.llm_provider_service import LLMProviderService | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from core.model_providers.models.entity.message import to_prompt_messages, MessageType | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.prompt.prompt_builder import PromptBuilder | from core.prompt.prompt_builder import PromptBuilder | ||||
| from core.prompt.prompt_template import JinjaPromptTemplate | from core.prompt.prompt_template import JinjaPromptTemplate | ||||
| from events.message_event import message_was_created | from events.message_event import message_was_created | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DatasetQuery | from models.dataset import DatasetQuery | ||||
| from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain | from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain | ||||
| from models.provider import ProviderType, Provider | |||||
| class ConversationMessageTask: | class ConversationMessageTask: | ||||
| def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, | def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, | ||||
| inputs: dict, query: str, streaming: bool, | |||||
| inputs: dict, query: str, streaming: bool, model_instance: BaseLLM, | |||||
| conversation: Optional[Conversation] = None, is_override: bool = False): | conversation: Optional[Conversation] = None, is_override: bool = False): | ||||
| self.task_id = task_id | self.task_id = task_id | ||||
| self.conversation = conversation | self.conversation = conversation | ||||
| self.is_new_conversation = False | self.is_new_conversation = False | ||||
| self.model_instance = model_instance | |||||
| self.message = None | self.message = None | ||||
| self.model_dict = self.app_model_config.model_dict | self.model_dict = self.app_model_config.model_dict | ||||
| self.provider_name = self.model_dict.get('provider') | |||||
| self.model_name = self.model_dict.get('name') | self.model_name = self.model_dict.get('name') | ||||
| self.mode = app.mode | self.mode = app.mode | ||||
| ) | ) | ||||
| def init(self): | def init(self): | ||||
| provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name) | |||||
| self.model_dict['provider'] = provider_name | |||||
| override_model_configs = None | override_model_configs = None | ||||
| if self.is_override: | if self.is_override: | ||||
| override_model_configs = { | override_model_configs = { | ||||
| if self.app_model_config.pre_prompt: | if self.app_model_config.pre_prompt: | ||||
| system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) | system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) | ||||
| system_instruction = system_message.content | system_instruction = system_message.content | ||||
| llm = LLMBuilder.to_llm(self.tenant_id, self.model_name) | |||||
| system_instruction_tokens = llm.get_num_tokens_from_messages([system_message]) | |||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=self.tenant_id, | |||||
| model_provider_name=self.provider_name, | |||||
| model_name=self.model_name | |||||
| ) | |||||
| system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message])) | |||||
| if not self.conversation: | if not self.conversation: | ||||
| self.is_new_conversation = True | self.is_new_conversation = True | ||||
| self.conversation = Conversation( | self.conversation = Conversation( | ||||
| app_id=self.app_model_config.app_id, | app_id=self.app_model_config.app_id, | ||||
| app_model_config_id=self.app_model_config.id, | app_model_config_id=self.app_model_config.id, | ||||
| model_provider=self.model_dict.get('provider'), | |||||
| model_provider=self.provider_name, | |||||
| model_id=self.model_name, | model_id=self.model_name, | ||||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | ||||
| mode=self.mode, | mode=self.mode, | ||||
| self.message = Message( | self.message = Message( | ||||
| app_id=self.app_model_config.app_id, | app_id=self.app_model_config.app_id, | ||||
| model_provider=self.model_dict.get('provider'), | |||||
| model_provider=self.provider_name, | |||||
| model_id=self.model_name, | model_id=self.model_name, | ||||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | ||||
| conversation_id=self.conversation.id, | conversation_id=self.conversation.id, | ||||
| answer_unit_price=0, | answer_unit_price=0, | ||||
| provider_response_latency=0, | provider_response_latency=0, | ||||
| total_price=0, | total_price=0, | ||||
| currency=llm_constant.model_currency, | |||||
| currency=self.model_instance.get_currency(), | |||||
| from_source=('console' if isinstance(self.user, Account) else 'api'), | from_source=('console' if isinstance(self.user, Account) else 'api'), | ||||
| from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), | from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), | ||||
| from_account_id=(self.user.id if isinstance(self.user, Account) else None), | from_account_id=(self.user.id if isinstance(self.user, Account) else None), | ||||
| self._pub_handler.pub_text(text) | self._pub_handler.pub_text(text) | ||||
| def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): | def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): | ||||
| model_name = self.app_model_config.model_dict.get('name') | |||||
| message_tokens = llm_message.prompt_tokens | message_tokens = llm_message.prompt_tokens | ||||
| answer_tokens = llm_message.completion_tokens | answer_tokens = llm_message.completion_tokens | ||||
| message_unit_price = llm_constant.model_prices[model_name]['prompt'] | |||||
| answer_unit_price = llm_constant.model_prices[model_name]['completion'] | |||||
| message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN) | |||||
| answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT) | |||||
| total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) | total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) | ||||
| self.message.provider_response_latency = llm_message.latency | self.message.provider_response_latency = llm_message.latency | ||||
| self.message.total_price = total_price | self.message.total_price = total_price | ||||
| self.update_provider_quota() | |||||
| db.session.commit() | db.session.commit() | ||||
| message_was_created.send( | message_was_created.send( | ||||
| if not by_stopped: | if not by_stopped: | ||||
| self.end() | self.end() | ||||
| def update_provider_quota(self): | |||||
| llm_provider_service = LLMProviderService( | |||||
| tenant_id=self.app.tenant_id, | |||||
| provider_name=self.message.model_provider, | |||||
| ) | |||||
| provider = llm_provider_service.get_provider_db_record() | |||||
| if provider and provider.provider_type == ProviderType.SYSTEM.value: | |||||
| db.session.query(Provider).filter( | |||||
| Provider.tenant_id == self.app.tenant_id, | |||||
| Provider.provider_name == provider.provider_name, | |||||
| Provider.quota_limit > Provider.quota_used | |||||
| ).update({'quota_used': Provider.quota_used + 1}) | |||||
| def init_chain(self, chain_result: ChainResult): | def init_chain(self, chain_result: ChainResult): | ||||
| message_chain = MessageChain( | message_chain = MessageChain( | ||||
| message_id=self.message.id, | message_id=self.message.id, | ||||
| return message_agent_thought | return message_agent_thought | ||||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str, | |||||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM, | |||||
| agent_loop: AgentLoop): | agent_loop: AgentLoop): | ||||
| agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt'] | |||||
| agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion'] | |||||
| agent_message_unit_price = agent_model_instance.get_token_price(1, MessageType.HUMAN) | |||||
| agent_answer_unit_price = agent_model_instance.get_token_price(1, MessageType.ASSISTANT) | |||||
| loop_message_tokens = agent_loop.prompt_tokens | loop_message_tokens = agent_loop.prompt_tokens | ||||
| loop_answer_tokens = agent_loop.completion_tokens | loop_answer_tokens = agent_loop.completion_tokens | ||||
| message_agent_thought.latency = agent_loop.latency | message_agent_thought.latency = agent_loop.latency | ||||
| message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | ||||
| message_agent_thought.total_price = loop_total_price | message_agent_thought.total_price = loop_total_price | ||||
| message_agent_thought.currency = llm_constant.model_currency | |||||
| message_agent_thought.currency = agent_model_instance.get_currency() | |||||
| db.session.flush() | db.session.flush() | ||||
| def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): | def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): |
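Pricing moves off the static `llm_constant.model_prices` table onto `get_token_price(1, MessageType.X)` calls on the instance. Assuming the deleted USD figures were per 1K tokens (which matches OpenAI's published prices), a worked total for one exchange looks like this:

```python
from decimal import Decimal

# Illustrative per-1K-token unit prices (the deleted gpt-4 constants).
message_unit_price = Decimal("0.03")   # prompt side
answer_unit_price = Decimal("0.06")    # completion side

message_tokens, answer_tokens = 1000, 500
total_price = (Decimal(message_tokens) / 1000 * message_unit_price
               + Decimal(answer_tokens) / 1000 * answer_unit_price)
print(total_price)  # 0.06 USD, under the per-1K-token assumption
```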
| from langchain.schema import Document | from langchain.schema import Document | ||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from core.llm.token_calculator import TokenCalculator | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, DocumentSegment | from models.dataset import Dataset, DocumentSegment | ||||
| self, | self, | ||||
| dataset: Dataset, | dataset: Dataset, | ||||
| user_id: str, | user_id: str, | ||||
| embedding_model_name: str, | |||||
| document_id: Optional[str] = None, | document_id: Optional[str] = None, | ||||
| ): | ): | ||||
| self._dataset = dataset | self._dataset = dataset | ||||
| self._user_id = user_id | self._user_id = user_id | ||||
| self._embedding_model_name = embedding_model_name | |||||
| self._document_id = document_id | self._document_id = document_id | ||||
| @classmethod | @classmethod | ||||
| def user_id(self) -> Any: | def user_id(self) -> Any: | ||||
| return self._user_id | return self._user_id | ||||
| @property | |||||
| def embedding_model_name(self) -> Any: | |||||
| return self._embedding_model_name | |||||
| @property | @property | ||||
| def docs(self) -> Dict[str, Document]: | def docs(self) -> Dict[str, Document]: | ||||
| document_segments = db.session.query(DocumentSegment).filter( | document_segments = db.session.query(DocumentSegment).filter( | ||||
| if max_position is None: | if max_position is None: | ||||
| max_position = 0 | max_position = 0 | ||||
| embedding_model = ModelFactory.get_embedding_model( | |||||
| tenant_id=self._dataset.tenant_id | |||||
| ) | |||||
| for doc in docs: | for doc in docs: | ||||
| if not isinstance(doc, Document): | if not isinstance(doc, Document): | ||||
| raise ValueError("doc must be a Document") | raise ValueError("doc must be a Document") | ||||
| ) | ) | ||||
| # calc embedding use tokens | # calc embedding use tokens | ||||
| tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) | |||||
| tokens = embedding_model.get_num_tokens(doc.page_content) | |||||
| if not segment_document: | if not segment_document: | ||||
| max_position += 1 | max_position += 1 |
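Segment token accounting likewise drops the static `TokenCalculator` for an embedding-model instance fetched once per dataset, outside the per-document loop. A stubbed sketch:

```python
class EmbeddingModelStub:
    def get_num_tokens(self, text: str) -> int:
        # Stand-in for the provider tokenizer: rough word count.
        return len(text.split())

embedding_model = EmbeddingModelStub()  # fetched once, not per document
docs = ["first segment text", "a slightly longer second segment"]
for doc in docs:
    tokens = embedding_model.get_num_tokens(doc)
    print(tokens)  # stored on the segment record in the real code
```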
| from langchain.embeddings.base import Embeddings | from langchain.embeddings.base import Embeddings | ||||
| from sqlalchemy.exc import IntegrityError | from sqlalchemy.exc import IntegrityError | ||||
| from core.llm.wrappers.openai_wrapper import handle_openai_exceptions | |||||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from libs import helper | from libs import helper | ||||
| from models.dataset import Embedding | from models.dataset import Embedding | ||||
| class CacheEmbedding(Embeddings): | class CacheEmbedding(Embeddings): | ||||
| def __init__(self, embeddings: Embeddings): | |||||
| def __init__(self, embeddings: BaseEmbedding): | |||||
| self._embeddings = embeddings | self._embeddings = embeddings | ||||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | def embed_documents(self, texts: List[str]) -> List[List[float]]: | ||||
| embedding_queue_texts = [] | embedding_queue_texts = [] | ||||
| for text in texts: | for text in texts: | ||||
| hash = helper.generate_text_hash(text) | hash = helper.generate_text_hash(text) | ||||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||||
| if embedding: | if embedding: | ||||
| text_embeddings.append(embedding.get_embedding()) | text_embeddings.append(embedding.get_embedding()) | ||||
| else: | else: | ||||
| embedding_queue_texts.append(text) | embedding_queue_texts.append(text) | ||||
| embedding_results = self._embeddings.embed_documents(embedding_queue_texts) | |||||
| if embedding_queue_texts: | |||||
| try: | |||||
| embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) | |||||
| except Exception as ex: | |||||
| raise self._embeddings.handle_exceptions(ex) | |||||
| i = 0 | |||||
| for text in embedding_queue_texts: | |||||
| hash = helper.generate_text_hash(text) | |||||
| i = 0 | |||||
| for text in embedding_queue_texts: | |||||
| hash = helper.generate_text_hash(text) | |||||
| try: | |||||
| embedding = Embedding(hash=hash) | |||||
| embedding.set_embedding(embedding_results[i]) | |||||
| db.session.add(embedding) | |||||
| db.session.commit() | |||||
| except IntegrityError: | |||||
| db.session.rollback() | |||||
| continue | |||||
| except: | |||||
| logging.exception('Failed to add embedding to db') | |||||
| continue | |||||
| finally: | |||||
| i += 1 | |||||
| try: | |||||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||||
| embedding.set_embedding(embedding_results[i]) | |||||
| db.session.add(embedding) | |||||
| db.session.commit() | |||||
| except IntegrityError: | |||||
| db.session.rollback() | |||||
| continue | |||||
| except: | |||||
| logging.exception('Failed to add embedding to db') | |||||
| continue | |||||
| finally: | |||||
| i += 1 | |||||
| text_embeddings.extend(embedding_results) | |||||
| text_embeddings.extend(embedding_results) | |||||
| return text_embeddings | return text_embeddings | ||||
| @handle_openai_exceptions | |||||
| def embed_query(self, text: str) -> List[float]: | def embed_query(self, text: str) -> List[float]: | ||||
| """Embed query text.""" | """Embed query text.""" | ||||
| # use doc embedding cache or store if not exists | # use doc embedding cache or store if not exists | ||||
| hash = helper.generate_text_hash(text) | hash = helper.generate_text_hash(text) | ||||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||||
| if embedding: | if embedding: | ||||
| return embedding.get_embedding() | return embedding.get_embedding() | ||||
| embedding_results = self._embeddings.embed_query(text) | |||||
| try: | |||||
| embedding_results = self._embeddings.client.embed_query(text) | |||||
| except Exception as ex: | |||||
| raise self._embeddings.handle_exceptions(ex) | |||||
| try: | try: | ||||
| embedding = Embedding(hash=hash) | |||||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||||
| embedding.set_embedding(embedding_results) | embedding.set_embedding(embedding_results) | ||||
| db.session.add(embedding) | db.session.add(embedding) | ||||
| db.session.commit() | db.session.commit() | ||||
| logging.exception('Failed to add embedding to db') | logging.exception('Failed to add embedding to db') | ||||
| return embedding_results | return embedding_results | ||||
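The embedding cache key grows from the text hash alone to `(model_name, hash)`, so two models embedding the same text no longer collide. A dict-backed sketch of the lookup-or-compute flow (the embedding itself is faked):

```python
import hashlib
from typing import Dict, List, Tuple

cache: Dict[Tuple[str, str], List[float]] = {}

def text_hash(text: str) -> str:
    return hashlib.sha256(text.encode()).hexdigest()

def embed_with_cache(model_name: str, text: str) -> List[float]:
    key = (model_name, text_hash(text))     # model-qualified key, as in the hunk
    if key not in cache:
        cache[key] = [float(len(text))]     # stand-in for the real embedding call
    return cache[key]

embed_with_cache("text-embedding-ada-002", "hello")
embed_with_cache("some-other-model", "hello")  # distinct entry, no collision
print(len(cache))  # 2
```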
| import logging | import logging | ||||
| from langchain import PromptTemplate | |||||
| from langchain.chat_models.base import BaseChatModel | |||||
| from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage | |||||
| from core.constant import llm_constant | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||||
| from core.llm.token_calculator import TokenCalculator | |||||
| from langchain.schema import OutputParserException | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||||
| from core.model_providers.models.entity.model_params import ModelKwargs | |||||
| from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser | from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser | ||||
| from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser | from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser | ||||
| from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ | from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ | ||||
| GENERATOR_QA_PROMPT | GENERATOR_QA_PROMPT | ||||
| # gpt-3.5-turbo does not work well | |||||
| generate_base_model = 'text-davinci-003' | |||||
| class LLMGenerator: | class LLMGenerator: | ||||
| @classmethod | @classmethod | ||||
| query = query[:300] + "...[TRUNCATED]..." + query[-300:] | query = query[:300] + "...[TRUNCATED]..." + query[-300:] | ||||
| prompt = prompt.format(query=query) | prompt = prompt.format(query=query) | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model_name='gpt-3.5-turbo', | |||||
| max_tokens=50, | |||||
| timeout=600 | |||||
| model_kwargs=ModelKwargs( | |||||
| max_tokens=50 | |||||
| ) | |||||
| ) | ) | ||||
| if isinstance(llm, BaseChatModel): | |||||
| prompt = [HumanMessage(content=prompt)] | |||||
| response = llm.generate([prompt]) | |||||
| answer = response.generations[0][0].text | |||||
| prompts = [PromptMessage(content=prompt)] | |||||
| response = model_instance.run(prompts) | |||||
| answer = response.content | |||||
| return answer.strip() | return answer.strip() | ||||
| @classmethod | @classmethod | ||||
| def generate_conversation_summary(cls, tenant_id: str, messages): | def generate_conversation_summary(cls, tenant_id: str, messages): | ||||
| max_tokens = 200 | max_tokens = 200 | ||||
| model = 'gpt-3.5-turbo' | |||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id, | |||||
| model_kwargs=ModelKwargs( | |||||
| max_tokens=max_tokens | |||||
| ) | |||||
| ) | |||||
| prompt = CONVERSATION_SUMMARY_PROMPT | prompt = CONVERSATION_SUMMARY_PROMPT | ||||
| prompt_with_empty_context = prompt.format(context='') | prompt_with_empty_context = prompt.format(context='') | ||||
| prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context) | |||||
| rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1 | |||||
| prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)]) | |||||
| max_context_token_length = model_instance.model_rules.max_tokens.max | |||||
| rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1 | |||||
| context = '' | context = '' | ||||
| for message in messages: | for message in messages: | ||||
| answer = message.answer | answer = message.answer | ||||
| message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer | message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer | ||||
| if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0: | |||||
| if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0: | |||||
| context += message_qa_text | context += message_qa_text | ||||
| if not context: | if not context: | ||||
| return '[message too long, no summary]' | return '[message too long, no summary]' | ||||
| prompt = prompt.format(context=context) | prompt = prompt.format(context=context) | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| tenant_id=tenant_id, | |||||
| model_name=model, | |||||
| max_tokens=max_tokens | |||||
| ) | |||||
| if isinstance(llm, BaseChatModel): | |||||
| prompt = [HumanMessage(content=prompt)] | |||||
| response = llm.generate([prompt]) | |||||
| answer = response.generations[0][0].text | |||||
| prompts = [PromptMessage(content=prompt)] | |||||
| response = model_instance.run(prompts) | |||||
| answer = response.content | |||||
| return answer.strip() | return answer.strip() | ||||
| @classmethod | @classmethod | ||||
| prompt = INTRODUCTION_GENERATE_PROMPT | prompt = INTRODUCTION_GENERATE_PROMPT | ||||
| prompt = prompt.format(prompt=pre_prompt) | prompt = prompt.format(prompt=pre_prompt) | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| tenant_id=tenant_id, | |||||
| model_name=generate_base_model, | |||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id | |||||
| ) | ) | ||||
| if isinstance(llm, BaseChatModel): | |||||
| prompt = [HumanMessage(content=prompt)] | |||||
| response = llm.generate([prompt]) | |||||
| answer = response.generations[0][0].text | |||||
| prompts = [PromptMessage(content=prompt)] | |||||
| response = model_instance.run(prompts) | |||||
| answer = response.content | |||||
| return answer.strip() | return answer.strip() | ||||
| @classmethod | @classmethod | ||||
| _input = prompt.format_prompt(histories=histories) | _input = prompt.format_prompt(histories=histories) | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model_name='gpt-3.5-turbo', | |||||
| temperature=0, | |||||
| max_tokens=256 | |||||
| model_kwargs=ModelKwargs( | |||||
| max_tokens=256, | |||||
| temperature=0 | |||||
| ) | |||||
| ) | ) | ||||
| if isinstance(llm, BaseChatModel): | |||||
| query = [HumanMessage(content=_input.to_string())] | |||||
| else: | |||||
| query = _input.to_string() | |||||
| prompts = [PromptMessage(content=_input.to_string())] | |||||
| try: | try: | ||||
| output = llm(query) | |||||
| if isinstance(output, BaseMessage): | |||||
| output = output.content | |||||
| questions = output_parser.parse(output) | |||||
| output = model_instance.run(prompts) | |||||
| questions = output_parser.parse(output.content) | |||||
| except Exception: | except Exception: | ||||
| logging.exception("Error generating suggested questions after answer") | logging.exception("Error generating suggested questions after answer") | ||||
| questions = [] | questions = [] | ||||
| _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) | _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model_name=generate_base_model, | |||||
| temperature=0, | |||||
| max_tokens=512 | |||||
| model_kwargs=ModelKwargs( | |||||
| max_tokens=512, | |||||
| temperature=0 | |||||
| ) | |||||
| ) | ) | ||||
| if isinstance(llm, BaseChatModel): | |||||
| query = [HumanMessage(content=_input.to_string())] | |||||
| else: | |||||
| query = _input.to_string() | |||||
| prompts = [PromptMessage(content=_input.to_string())] | |||||
| try: | try: | ||||
| output = llm(query) | |||||
| rule_config = output_parser.parse(output) | |||||
| output = model_instance.run(prompts) | |||||
| rule_config = output_parser.parse(output.content) | |||||
| except OutputParserException: | except OutputParserException: | ||||
| raise ValueError('Please give a valid input for intended audience or hoping to solve problems.') | raise ValueError('Please give a valid input for intended audience or hoping to solve problems.') | ||||
| except Exception: | except Exception: | ||||
| return rule_config | return rule_config | ||||
| @classmethod | @classmethod | ||||
| async def generate_qa_document(cls, llm: StreamableOpenAI, query): | |||||
| def generate_qa_document(cls, tenant_id: str, query): | |||||
| prompt = GENERATOR_QA_PROMPT | prompt = GENERATOR_QA_PROMPT | ||||
| model_instance = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id, | |||||
| model_kwargs=ModelKwargs( | |||||
| max_tokens=2000 | |||||
| ) | |||||
| ) | |||||
| if isinstance(llm, BaseChatModel): | |||||
| prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] | |||||
| response = llm.generate([prompt]) | |||||
| answer = response.generations[0][0].text | |||||
| return answer.strip() | |||||
| @classmethod | |||||
| def generate_qa_document_sync(cls, llm: StreamableOpenAI, query): | |||||
| prompt = GENERATOR_QA_PROMPT | |||||
| if isinstance(llm, BaseChatModel): | |||||
| prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] | |||||
| prompts = [ | |||||
| PromptMessage(content=prompt, type=MessageType.SYSTEM), | |||||
| PromptMessage(content=query) | |||||
| ] | |||||
| response = llm.generate([prompt]) | |||||
| answer = response.generations[0][0].text | |||||
| response = model_instance.run(prompts) | |||||
| answer = response.content | |||||
| return answer.strip() | return answer.strip() |
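| # Editor's sketch (not part of the diff): the post-refactor call pattern used | |||||
| # throughout this file. `summarize` is a hypothetical helper; ModelFactory, | |||||
| # PromptMessage and MessageType are the APIs introduced in the added lines above. | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||||
| def summarize(tenant_id: str, text: str) -> str: | |||||
|     # Resolve the tenant's configured default text-generation model instead of | |||||
|     # hard-coding a provider/model pair as the removed LLMBuilder calls did. | |||||
|     model_instance = ModelFactory.get_text_generation_model(tenant_id=tenant_id) | |||||
|     prompts = [PromptMessage(content='Summarize: ' + text, type=MessageType.HUMAN)] | |||||
|     response = model_instance.run(prompts) | |||||
|     return response.content.strip() | |||||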
| import base64 | |||||
| from extensions.ext_database import db | |||||
| from libs import rsa | |||||
| from models.account import Tenant | |||||
| def obfuscated_token(token: str): | |||||
| return token[:6] + '*' * (len(token) - 8) + token[-2:] | |||||
| def encrypt_token(tenant_id: str, token: str): | |||||
| tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() | |||||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||||
| return base64.b64encode(encrypted_token).decode() | |||||
| def decrypt_token(tenant_id: str, token: str): | |||||
| return rsa.decrypt(base64.b64decode(token), tenant_id) |
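| # Editor's sketch: round-trip usage of the helpers above. Assumes a Tenant row | |||||
| # with an RSA keypair already exists; tenant_id below is a placeholder value. | |||||
| tenant_id = '00000000-0000-0000-0000-000000000000' | |||||
| token = 'sk-abcdef1234567890' | |||||
| ciphertext = encrypt_token(tenant_id, token)  # base64-encoded RSA ciphertext | |||||
| assert decrypt_token(tenant_id, ciphertext) == token | |||||
| # obfuscated_token keeps the first 6 and last 2 characters visible: | |||||
| print(obfuscated_token(token))  # sk-abc***********90 | |||||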
| from flask import current_app | from flask import current_app | ||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from core.embedding.cached_embedding import CacheEmbedding | from core.embedding.cached_embedding import CacheEmbedding | ||||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | ||||
| from core.index.vector_index.vector_index import VectorIndex | from core.index.vector_index.vector_index import VectorIndex | ||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from models.dataset import Dataset | from models.dataset import Dataset | ||||
| if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | ||||
| return None | return None | ||||
| model_credentials = LLMBuilder.get_model_credentials( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'), | |||||
| model_name='text-embedding-ada-002' | |||||
| embedding_model = ModelFactory.get_embedding_model( | |||||
| tenant_id=dataset.tenant_id | |||||
| ) | ) | ||||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||||
| max_retries=1, | |||||
| **model_credentials | |||||
| )) | |||||
| embeddings = CacheEmbedding(embedding_model) | |||||
| return VectorIndex( | return VectorIndex( | ||||
| dataset=dataset, | dataset=dataset, |
| import concurrent | |||||
| import datetime | import datetime | ||||
| import json | import json | ||||
| import logging | import logging | ||||
| import threading | import threading | ||||
| import time | import time | ||||
| import uuid | import uuid | ||||
| from concurrent.futures import ThreadPoolExecutor | |||||
| from typing import Optional, List, cast | from typing import Optional, List, cast | ||||
| from flask_login import current_user | from flask_login import current_user | ||||
| from core.docstore.dataset_docstore import DatesetDocumentStore | from core.docstore.dataset_docstore import DatesetDocumentStore | ||||
| from core.generator.llm_generator import LLMGenerator | from core.generator.llm_generator import LLMGenerator | ||||
| from core.index.index import IndexBuilder | from core.index.index import IndexBuilder | ||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.llm.llm_builder import LLMBuilder | |||||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||||
| from core.model_providers.error import ProviderTokenNotInitError | |||||
| from core.model_providers.model_factory import ModelFactory | |||||
| from core.model_providers.models.entity.message import MessageType | |||||
| from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | ||||
| from core.llm.token_calculator import TokenCalculator | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| class IndexingRunner: | class IndexingRunner: | ||||
| def __init__(self, embedding_model_name: str = "text-embedding-ada-002"): | |||||
| def __init__(self): | |||||
| self.storage = storage | self.storage = storage | ||||
| self.embedding_model_name = embedding_model_name | |||||
| def run(self, dataset_documents: List[DatasetDocument]): | def run(self, dataset_documents: List[DatasetDocument]): | ||||
| """Run the indexing process.""" | """Run the indexing process.""" | ||||
| dataset_document.stopped_at = datetime.datetime.utcnow() | dataset_document.stopped_at = datetime.datetime.utcnow() | ||||
| db.session.commit() | db.session.commit() | ||||
| def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, | |||||
| def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict, | |||||
| doc_form: str = None) -> dict: | doc_form: str = None) -> dict: | ||||
| """ | """ | ||||
| Estimate the indexing for the document. | Estimate the indexing for the document. | ||||
| """ | """ | ||||
| embedding_model = ModelFactory.get_embedding_model( | |||||
| tenant_id=tenant_id | |||||
| ) | |||||
| tokens = 0 | tokens = 0 | ||||
| preview_texts = [] | preview_texts = [] | ||||
| total_segments = 0 | total_segments = 0 | ||||
| splitter=splitter, | splitter=splitter, | ||||
| processing_rule=processing_rule | processing_rule=processing_rule | ||||
| ) | ) | ||||
| total_segments += len(documents) | total_segments += len(documents) | ||||
| for document in documents: | for document in documents: | ||||
| if len(preview_texts) < 5: | if len(preview_texts) < 5: | ||||
| preview_texts.append(document.page_content) | preview_texts.append(document.page_content) | ||||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, | |||||
| self.filter_string(document.page_content)) | |||||
| tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content)) | |||||
| text_generation_model = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id | |||||
| ) | |||||
| if doc_form and doc_form == 'qa_model': | if doc_form and doc_form == 'qa_model': | ||||
| if len(preview_texts) > 0: | if len(preview_texts) > 0: | ||||
| # qa model document | # qa model document | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| model_name='gpt-3.5-turbo', | |||||
| max_tokens=2000 | |||||
| ) | |||||
| response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) | |||||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) | |||||
| document_qa_list = self.format_split_text(response) | document_qa_list = self.format_split_text(response) | ||||
| return { | return { | ||||
| "total_segments": total_segments * 20, | "total_segments": total_segments * 20, | ||||
| "tokens": total_segments * 2000, | "tokens": total_segments * 2000, | ||||
| "total_price": '{:f}'.format( | "total_price": '{:f}'.format( | ||||
| TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), | |||||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||||
| text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), | |||||
| "currency": embedding_model.get_currency(), | |||||
| "qa_preview": document_qa_list, | "qa_preview": document_qa_list, | ||||
| "preview": preview_texts | "preview": preview_texts | ||||
| } | } | ||||
| return { | return { | ||||
| "total_segments": total_segments, | "total_segments": total_segments, | ||||
| "tokens": tokens, | "tokens": tokens, | ||||
| "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), | |||||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||||
| "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), | |||||
| "currency": embedding_model.get_currency(), | |||||
| "preview": preview_texts | "preview": preview_texts | ||||
| } | } | ||||
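| # Editorial note: in qa_model mode the estimate above uses fixed heuristics | |||||
| # baked into the code (roughly 20 generated segments and 2,000 completion | |||||
| # tokens per source segment), not measured counts. | |||||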
| def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: | |||||
| def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: | |||||
| """ | """ | ||||
| Estimate the indexing for the document. | Estimate the indexing for the document. | ||||
| """ | """ | ||||
| embedding_model = ModelFactory.get_embedding_model( | |||||
| tenant_id=tenant_id | |||||
| ) | |||||
| # load data from notion | # load data from notion | ||||
| tokens = 0 | tokens = 0 | ||||
| preview_texts = [] | preview_texts = [] | ||||
| if len(preview_texts) < 5: | if len(preview_texts) < 5: | ||||
| preview_texts.append(document.page_content) | preview_texts.append(document.page_content) | ||||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||||
| tokens += embedding_model.get_num_tokens(document.page_content) | |||||
| text_generation_model = ModelFactory.get_text_generation_model( | |||||
| tenant_id=tenant_id | |||||
| ) | |||||
| if doc_form and doc_form == 'qa_model': | if doc_form and doc_form == 'qa_model': | ||||
| if len(preview_texts) > 0: | if len(preview_texts) > 0: | ||||
| # qa model document | # qa model document | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| tenant_id=current_user.current_tenant_id, | |||||
| model_name='gpt-3.5-turbo', | |||||
| max_tokens=2000 | |||||
| ) | |||||
| response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) | |||||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) | |||||
| document_qa_list = self.format_split_text(response) | document_qa_list = self.format_split_text(response) | ||||
| return { | return { | ||||
| "total_segments": total_segments * 20, | "total_segments": total_segments * 20, | ||||
| "tokens": total_segments * 2000, | "tokens": total_segments * 2000, | ||||
| "total_price": '{:f}'.format( | "total_price": '{:f}'.format( | ||||
| TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), | |||||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||||
| text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), | |||||
| "currency": embedding_model.get_currency(), | |||||
| "qa_preview": document_qa_list, | "qa_preview": document_qa_list, | ||||
| "preview": preview_texts | "preview": preview_texts | ||||
| } | } | ||||
| return { | return { | ||||
| "total_segments": total_segments, | "total_segments": total_segments, | ||||
| "tokens": tokens, | "tokens": tokens, | ||||
| "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), | |||||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||||
| "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), | |||||
| "currency": embedding_model.get_currency(), | |||||
| "preview": preview_texts | "preview": preview_texts | ||||
| } | } | ||||
| doc_store = DatesetDocumentStore( | doc_store = DatesetDocumentStore( | ||||
| dataset=dataset, | dataset=dataset, | ||||
| user_id=dataset_document.created_by, | user_id=dataset_document.created_by, | ||||
| embedding_model_name=self.embedding_model_name, | |||||
| document_id=dataset_document.id | document_id=dataset_document.id | ||||
| ) | ) | ||||
| all_documents.extend(split_documents) | all_documents.extend(split_documents) | ||||
| # processing qa document | # processing qa document | ||||
| if document_form == 'qa_model': | if document_form == 'qa_model': | ||||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||||
| tenant_id=tenant_id, | |||||
| model_name='gpt-3.5-turbo', | |||||
| max_tokens=2000 | |||||
| ) | |||||
| for i in range(0, len(all_documents), 10): | for i in range(0, len(all_documents), 10): | ||||
| threads = [] | threads = [] | ||||
| sub_documents = all_documents[i:i + 10] | sub_documents = all_documents[i:i + 10] | ||||
| for doc in sub_documents: | for doc in sub_documents: | ||||
| document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ | document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ | ||||
| 'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents}) | |||||
| 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents}) | |||||
| threads.append(document_format_thread) | threads.append(document_format_thread) | ||||
| document_format_thread.start() | document_format_thread.start() | ||||
| for thread in threads: | for thread in threads: | ||||
| return all_qa_documents | return all_qa_documents | ||||
| return all_documents | return all_documents | ||||
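| # Editor's sketch of the 10-at-a-time threading pattern used above, with a | |||||
| # stub worker standing in for format_qa_document (all names are illustrative): | |||||
| import threading | |||||
| def _stub_worker(doc: str, results: list): | |||||
|     results.append(doc.upper())  # list.append is atomic in CPython | |||||
| docs = ['doc-%d' % n for n in range(25)] | |||||
| results = [] | |||||
| for i in range(0, len(docs), 10): | |||||
|     batch = [threading.Thread(target=_stub_worker, args=(d, results)) for d in docs[i:i + 10]] | |||||
|     for t in batch: | |||||
|         t.start() | |||||
|     for t in batch: | |||||
|         t.join()  # wait for the whole batch before starting the next one | |||||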
| def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents): | |||||
| def format_qa_document(self, tenant_id: str, document_node, all_qa_documents): | |||||
| format_documents = [] | format_documents = [] | ||||
| if document_node.page_content is None or not document_node.page_content.strip(): | if document_node.page_content is None or not document_node.page_content.strip(): | ||||
| return | return | ||||
| try: | try: | ||||
| # qa model document | # qa model document | ||||
| response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) | |||||
| response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content) | |||||
| document_qa_list = self.format_split_text(response) | document_qa_list = self.format_split_text(response) | ||||
| qa_documents = [] | qa_documents = [] | ||||
| for result in document_qa_list: | for result in document_qa_list: | ||||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | vector_index = IndexBuilder.get_index(dataset, 'high_quality') | ||||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | ||||
| embedding_model = ModelFactory.get_embedding_model( | |||||
| tenant_id=dataset.tenant_id | |||||
| ) | |||||
| # chunk nodes by chunk size | # chunk nodes by chunk size | ||||
| indexing_start_at = time.perf_counter() | indexing_start_at = time.perf_counter() | ||||
| tokens = 0 | tokens = 0 | ||||
| chunk_documents = documents[i:i + chunk_size] | chunk_documents = documents[i:i + chunk_size] | ||||
| tokens += sum( | tokens += sum( | ||||
| TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||||
| embedding_model.get_num_tokens(document.page_content) | |||||
| for document in chunk_documents | for document in chunk_documents | ||||
| ) | ) | ||||
| from typing import Union, Optional, List | |||||
| from langchain.callbacks.base import BaseCallbackHandler | |||||
| from core.constant import llm_constant | |||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from core.llm.provider.llm_provider_service import LLMProviderService | |||||
| from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI | |||||
| from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI | |||||
| from core.llm.streamable_chat_anthropic import StreamableChatAnthropic | |||||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | |||||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||||
| from models.provider import ProviderType, ProviderName | |||||
| class LLMBuilder: | |||||
| """ | |||||
| This class handles the following logic: | |||||
| 1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config. | |||||
| 2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below: | |||||
| OPENAI_API_TYPE=azure | |||||
| OPENAI_API_VERSION=2022-12-01 | |||||
| OPENAI_API_BASE=https://your-resource-name.openai.azure.com | |||||
| OPENAI_API_KEY=<your Azure OpenAI API key> | |||||
| 3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config. | |||||
| 4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config. | |||||
| 5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config. | |||||
| 6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot. | |||||
| 7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter. | |||||
| 8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting. | |||||
| """ | |||||
| @classmethod | |||||
| def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||||
| provider = cls.get_default_provider(tenant_id, model_name) | |||||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||||
| llm_cls = None | |||||
| mode = cls.get_mode_by_model(model_name) | |||||
| if mode == 'chat': | |||||
| if provider == ProviderName.OPENAI.value: | |||||
| llm_cls = StreamableChatOpenAI | |||||
| elif provider == ProviderName.AZURE_OPENAI.value: | |||||
| llm_cls = StreamableAzureChatOpenAI | |||||
| elif provider == ProviderName.ANTHROPIC.value: | |||||
| llm_cls = StreamableChatAnthropic | |||||
| elif mode == 'completion': | |||||
| if provider == ProviderName.OPENAI.value: | |||||
| llm_cls = StreamableOpenAI | |||||
| elif provider == ProviderName.AZURE_OPENAI.value: | |||||
| llm_cls = StreamableAzureOpenAI | |||||
| if not llm_cls: | |||||
| raise ValueError(f"model name {model_name} is not supported.") | |||||
| model_kwargs = { | |||||
| 'model_name': model_name, | |||||
| 'temperature': kwargs.get('temperature', 0), | |||||
| 'max_tokens': kwargs.get('max_tokens', 256), | |||||
| 'top_p': kwargs.get('top_p', 1), | |||||
| 'frequency_penalty': kwargs.get('frequency_penalty', 0), | |||||
| 'presence_penalty': kwargs.get('presence_penalty', 0), | |||||
| 'callbacks': kwargs.get('callbacks', None), | |||||
| 'streaming': kwargs.get('streaming', False), | |||||
| } | |||||
| model_kwargs.update(model_credentials) | |||||
| model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs) | |||||
| return llm_cls(**model_kwargs) | |||||
| @classmethod | |||||
| def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, | |||||
| callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||||
| model_name = model.get("name") | |||||
| completion_params = model.get("completion_params", {}) | |||||
| return cls.to_llm( | |||||
| tenant_id=tenant_id, | |||||
| model_name=model_name, | |||||
| temperature=completion_params.get('temperature', 0), | |||||
| max_tokens=completion_params.get('max_tokens', 256), | |||||
| top_p=completion_params.get('top_p', 0), | |||||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | |||||
| presence_penalty=completion_params.get('presence_penalty', 0.1), | |||||
| streaming=streaming, | |||||
| callbacks=callbacks | |||||
| ) | |||||
| @classmethod | |||||
| def get_mode_by_model(cls, model_name: str) -> str: | |||||
| if not model_name: | |||||
| raise ValueError(f"empty model name is not supported.") | |||||
| if model_name in llm_constant.models_by_mode['chat']: | |||||
| return "chat" | |||||
| elif model_name in llm_constant.models_by_mode['completion']: | |||||
| return "completion" | |||||
| else: | |||||
| raise ValueError(f"model name {model_name} is not supported.") | |||||
| @classmethod | |||||
| def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict: | |||||
| """ | |||||
| Returns the API credentials for the given tenant_id and model_name, based on the model's provider. | |||||
| Raises an exception if the model_name is not found or if the provider is not found. | |||||
| """ | |||||
| if not model_name: | |||||
| raise Exception('model name not found') | |||||
| # | |||||
| # if model_name not in llm_constant.models: | |||||
| # raise Exception('model {} not found'.format(model_name)) | |||||
| # model_provider = llm_constant.models[model_name] | |||||
| provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) | |||||
| return provider_service.get_credentials(model_name) | |||||
| @classmethod | |||||
| def get_default_provider(cls, tenant_id: str, model_name: str) -> str: | |||||
| provider_name = llm_constant.models[model_name] | |||||
| if provider_name == 'openai': | |||||
| # get the default provider (openai / azure_openai) for the tenant | |||||
| openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value) | |||||
| azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value) | |||||
| provider = None | |||||
| if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value: | |||||
| provider = openai_provider | |||||
| elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value: | |||||
| provider = azure_openai_provider | |||||
| elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value: | |||||
| provider = openai_provider | |||||
| elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value: | |||||
| provider = azure_openai_provider | |||||
| if not provider: | |||||
| raise ProviderTokenNotInitError( | |||||
| f"No valid {provider_name} model provider credentials found. " | |||||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||||
| ) | |||||
| provider_name = provider.provider_name | |||||
| return provider_name |
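| # Editorial note: this LLMBuilder module is removed by the PR. Call sites shown | |||||
| # earlier in the diff migrate roughly as follows: | |||||
| #   LLMBuilder.to_llm(tenant_id=..., model_name=..., max_tokens=...) | |||||
| #     -> ModelFactory.get_text_generation_model(tenant_id=..., | |||||
| #            model_kwargs=ModelKwargs(max_tokens=...)) | |||||
| #   LLMBuilder.get_model_credentials(...) for 'text-embedding-ada-002' | |||||
| #     -> ModelFactory.get_embedding_model(tenant_id=...) | |||||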
| import openai | |||||
| from models.provider import ProviderName | |||||
| class Moderation: | |||||
| def __init__(self, provider: str, api_key: str): | |||||
| self.provider = provider | |||||
| self.api_key = api_key | |||||
| if self.provider == ProviderName.OPENAI.value: | |||||
| self.client = openai.Moderation | |||||
| def moderate(self, text): | |||||
| return self.client.create(input=text, api_key=self.api_key) |
| import json | |||||
| import logging | |||||
| from typing import Optional, Union | |||||
| import anthropic | |||||
| from langchain.chat_models import ChatAnthropic | |||||
| from langchain.schema import HumanMessage | |||||
| from core import hosted_llm_credentials | |||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from core.llm.provider.errors import ValidateFailedError | |||||
| from models.provider import ProviderName, ProviderType | |||||
| class AnthropicProvider(BaseProvider): | |||||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||||
| return [ | |||||
| { | |||||
| 'id': 'claude-instant-1', | |||||
| 'name': 'claude-instant-1', | |||||
| }, | |||||
| { | |||||
| 'id': 'claude-2', | |||||
| 'name': 'claude-2', | |||||
| }, | |||||
| ] | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||||
| return self.get_provider_api_key(model_id=model_id) | |||||
| def get_provider_name(self): | |||||
| return ProviderName.ANTHROPIC | |||||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||||
| """ | |||||
| Returns the provider configs. | |||||
| """ | |||||
| try: | |||||
| config = self.get_provider_api_key(only_custom=only_custom) | |||||
| except: | |||||
| config = { | |||||
| 'anthropic_api_key': '' | |||||
| } | |||||
| if obfuscated: | |||||
| if not config.get('anthropic_api_key'): | |||||
| config = { | |||||
| 'anthropic_api_key': '' | |||||
| } | |||||
| config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key')) | |||||
| return config | |||||
| return config | |||||
| def get_encrypted_token(self, config: Union[dict | str]): | |||||
| """ | |||||
| Returns the encrypted token. | |||||
| """ | |||||
| return json.dumps({ | |||||
| 'anthropic_api_key': self.encrypt_token(config['anthropic_api_key']) | |||||
| }) | |||||
| def get_decrypted_token(self, token: str): | |||||
| """ | |||||
| Returns the decrypted token. | |||||
| """ | |||||
| config = json.loads(token) | |||||
| config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key']) | |||||
| return config | |||||
| def get_token_type(self): | |||||
| return dict | |||||
| def config_validate(self, config: Union[dict | str]): | |||||
| """ | |||||
| Validates the given config. | |||||
| """ | |||||
| # check OpenAI / Azure OpenAI credential is valid | |||||
| openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value) | |||||
| azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value) | |||||
| provider = None | |||||
| if openai_provider: | |||||
| provider = openai_provider | |||||
| elif azure_openai_provider: | |||||
| provider = azure_openai_provider | |||||
| if not provider: | |||||
| raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.") | |||||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||||
| quota_used = provider.quota_used if provider.quota_used is not None else 0 | |||||
| quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 | |||||
| if quota_used >= quota_limit: | |||||
| raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, " | |||||
| f"please configure OpenAI or Azure OpenAI provider first.") | |||||
| try: | |||||
| if not isinstance(config, dict): | |||||
| raise ValueError('Config must be an object.') | |||||
| if 'anthropic_api_key' not in config: | |||||
| raise ValueError('anthropic_api_key must be provided.') | |||||
| chat_llm = ChatAnthropic( | |||||
| model='claude-instant-1', | |||||
| anthropic_api_key=config['anthropic_api_key'], | |||||
| max_tokens_to_sample=10, | |||||
| temperature=0, | |||||
| default_request_timeout=60 | |||||
| ) | |||||
| messages = [ | |||||
| HumanMessage( | |||||
| content="ping" | |||||
| ) | |||||
| ] | |||||
| chat_llm(messages) | |||||
| except anthropic.APIConnectionError as ex: | |||||
| raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}") | |||||
| except (anthropic.APIStatusError, anthropic.RateLimitError) as ex: | |||||
| raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - " | |||||
| f"{ex.body['error']['type']}: {ex.body['error']['message']}") | |||||
| except Exception as ex: | |||||
| logging.exception('Anthropic config validation failed') | |||||
| raise ex | |||||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||||
| if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key: | |||||
| raise ProviderTokenNotInitError( | |||||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||||
| ) | |||||
| return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key} |
| import json | |||||
| import logging | |||||
| from typing import Optional, Union | |||||
| import openai | |||||
| import requests | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from core.llm.provider.errors import ValidateFailedError | |||||
| from models.provider import ProviderName | |||||
| AZURE_OPENAI_API_VERSION = '2023-07-01-preview' | |||||
| class AzureProvider(BaseProvider): | |||||
| def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: | |||||
| return [] | |||||
| def check_embedding_model(self, credentials: Optional[dict] = None): | |||||
| credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials | |||||
| try: | |||||
| result = openai.Embedding.create(input=['test'], | |||||
| engine='text-embedding-ada-002', | |||||
| timeout=60, | |||||
| api_key=str(credentials.get('openai_api_key')), | |||||
| api_base=str(credentials.get('openai_api_base')), | |||||
| api_type='azure', | |||||
| api_version=str(credentials.get('openai_api_version')))["data"][0][ | |||||
| "embedding"] | |||||
| except openai.error.AuthenticationError as e: | |||||
| raise AzureAuthenticationError(str(e)) | |||||
| except openai.error.APIConnectionError as e: | |||||
| raise AzureRequestFailedError( | |||||
| 'Failed to request Azure OpenAI, please check your API Base Endpoint. The format is `https://xxx.openai.azure.com/`') | |||||
| except openai.error.InvalidRequestError as e: | |||||
| if e.http_status == 404: | |||||
| raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' " | |||||
| "deployment name is exists in Azure AI") | |||||
| else: | |||||
| raise AzureRequestFailedError( | |||||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||||
| except openai.error.OpenAIError as e: | |||||
| raise AzureRequestFailedError( | |||||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||||
| if not isinstance(result, list): | |||||
| raise AzureRequestFailedError('Failed to request Azure OpenAI.') | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||||
| """ | |||||
| Returns the API credentials for Azure OpenAI as a dictionary. | |||||
| """ | |||||
| config = self.get_provider_api_key(model_id=model_id) | |||||
| config['openai_api_type'] = 'azure' | |||||
| config['openai_api_version'] = AZURE_OPENAI_API_VERSION | |||||
| if model_id == 'text-embedding-ada-002': | |||||
| config['deployment'] = model_id.replace('.', '') if model_id else None | |||||
| config['chunk_size'] = 16 | |||||
| else: | |||||
| config['deployment_name'] = model_id.replace('.', '') if model_id else None | |||||
| return config | |||||
| def get_provider_name(self): | |||||
| return ProviderName.AZURE_OPENAI | |||||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||||
| """ | |||||
| Returns the provider configs. | |||||
| """ | |||||
| try: | |||||
| config = self.get_provider_api_key(only_custom=only_custom) | |||||
| except: | |||||
| config = { | |||||
| 'openai_api_type': 'azure', | |||||
| 'openai_api_version': AZURE_OPENAI_API_VERSION, | |||||
| 'openai_api_base': '', | |||||
| 'openai_api_key': '' | |||||
| } | |||||
| if obfuscated: | |||||
| if not config.get('openai_api_key'): | |||||
| config = { | |||||
| 'openai_api_type': 'azure', | |||||
| 'openai_api_version': AZURE_OPENAI_API_VERSION, | |||||
| 'openai_api_base': '', | |||||
| 'openai_api_key': '' | |||||
| } | |||||
| config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key')) | |||||
| return config | |||||
| return config | |||||
| def get_token_type(self): | |||||
| return dict | |||||
| def config_validate(self, config: Union[dict | str]): | |||||
| """ | |||||
| Validates the given config. | |||||
| """ | |||||
| try: | |||||
| if not isinstance(config, dict): | |||||
| raise ValueError('Config must be an object.') | |||||
| if 'openai_api_version' not in config: | |||||
| config['openai_api_version'] = AZURE_OPENAI_API_VERSION | |||||
| self.check_embedding_model(credentials=config) | |||||
| except ValidateFailedError as e: | |||||
| raise e | |||||
| except AzureAuthenticationError: | |||||
| raise ValidateFailedError('Validation failed, please check your API Key.') | |||||
| except AzureRequestFailedError as ex: | |||||
| raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | |||||
| except Exception as ex: | |||||
| logging.exception('Azure OpenAI Credentials validation failed') | |||||
| raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | |||||
| def get_encrypted_token(self, config: Union[dict | str]): | |||||
| """ | |||||
| Returns the encrypted token. | |||||
| """ | |||||
| return json.dumps({ | |||||
| 'openai_api_type': 'azure', | |||||
| 'openai_api_version': AZURE_OPENAI_API_VERSION, | |||||
| 'openai_api_base': config['openai_api_base'], | |||||
| 'openai_api_key': self.encrypt_token(config['openai_api_key']) | |||||
| }) | |||||
| def get_decrypted_token(self, token: str): | |||||
| """ | |||||
| Returns the decrypted token. | |||||
| """ | |||||
| config = json.loads(token) | |||||
| config['openai_api_key'] = self.decrypt_token(config['openai_api_key']) | |||||
| return config | |||||
| class AzureAuthenticationError(Exception): | |||||
| pass | |||||
| class AzureRequestFailedError(Exception): | |||||
| pass |
| import base64 | |||||
| from abc import ABC, abstractmethod | |||||
| from typing import Optional, Union | |||||
| from core.constant import llm_constant | |||||
| from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError | |||||
| from extensions.ext_database import db | |||||
| from libs import rsa | |||||
| from models.account import Tenant | |||||
| from models.provider import Provider, ProviderType, ProviderName | |||||
| class BaseProvider(ABC): | |||||
| def __init__(self, tenant_id: str): | |||||
| self.tenant_id = tenant_id | |||||
| def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]: | |||||
| """ | |||||
| Returns the decrypted API key for the given tenant_id and provider_name. | |||||
| If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. | |||||
| If the provider is not found or not valid, raises a ProviderTokenNotInitError. | |||||
| """ | |||||
| provider = self.get_provider(only_custom) | |||||
| if not provider: | |||||
| raise ProviderTokenNotInitError( | |||||
| f"No valid {llm_constant.models[model_id]} model provider credentials found. " | |||||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||||
| ) | |||||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||||
| quota_used = provider.quota_used if provider.quota_used is not None else 0 | |||||
| quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 | |||||
| if model_id and model_id == 'gpt-4': | |||||
| raise ModelCurrentlyNotSupportError() | |||||
| if quota_used >= quota_limit: | |||||
| raise QuotaExceededError() | |||||
| return self.get_hosted_credentials() | |||||
| else: | |||||
| return self.get_decrypted_token(provider.encrypted_config) | |||||
| def get_provider(self, only_custom: bool = False) -> Optional[Provider]: | |||||
| """ | |||||
| Returns the Provider instance for the given tenant_id and provider_name. | |||||
| If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is preferred; pass only_custom to exclude SYSTEM providers. | |||||
| """ | |||||
| return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom) | |||||
| @classmethod | |||||
| def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[ | |||||
| Provider]: | |||||
| """ | |||||
| Returns the Provider instance for the given tenant_id and provider_name. | |||||
| If both CUSTOM and SYSTEM providers exist, the valid CUSTOM provider is preferred. | |||||
| """ | |||||
| query = db.session.query(Provider).filter( | |||||
| Provider.tenant_id == tenant_id | |||||
| ) | |||||
| if provider_name: | |||||
| query = query.filter(Provider.provider_name == provider_name) | |||||
| if only_custom: | |||||
| query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value) | |||||
| providers = query.order_by(Provider.provider_type.asc()).all() | |||||
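| # Editorial note: ascending order on provider_type means 'custom' sorts before | |||||
| # 'system' (assuming those lowercase enum values), so the loop below prefers a | |||||
| # valid CUSTOM provider over a SYSTEM one. | |||||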
| for provider in providers: | |||||
| if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: | |||||
| return provider | |||||
| elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: | |||||
| return provider | |||||
| return None | |||||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||||
| raise ProviderTokenNotInitError( | |||||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||||
| ) | |||||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||||
| """ | |||||
| Returns the provider configs. | |||||
| """ | |||||
| try: | |||||
| config = self.get_provider_api_key(only_custom=only_custom) | |||||
| except: | |||||
| config = '' | |||||
| if obfuscated: | |||||
| return self.obfuscated_token(config) | |||||
| return config | |||||
| def obfuscated_token(self, token: str): | |||||
| return token[:6] + '*' * (len(token) - 8) + token[-2:] | |||||
| def get_token_type(self): | |||||
| return str | |||||
| def get_encrypted_token(self, config: Union[dict | str]): | |||||
| return self.encrypt_token(config) | |||||
| def get_decrypted_token(self, token: str): | |||||
| return self.decrypt_token(token) | |||||
| def encrypt_token(self, token): | |||||
| tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() | |||||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||||
| return base64.b64encode(encrypted_token).decode() | |||||
| def decrypt_token(self, token): | |||||
| return rsa.decrypt(base64.b64decode(token), self.tenant_id) | |||||
| @abstractmethod | |||||
| def get_provider_name(self): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def config_validate(self, config: str): | |||||
| raise NotImplementedError |
| class ValidateFailedError(Exception): | |||||
| description = "Provider Validate failed" |
| from typing import Optional | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from models.provider import ProviderName | |||||
| class HuggingfaceProvider(BaseProvider): | |||||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||||
| credentials = self.get_credentials(model_id) | |||||
| # todo | |||||
| return [] | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||||
| """ | |||||
| Returns the API credentials for Huggingface as a dictionary, for the given tenant_id. | |||||
| """ | |||||
| return { | |||||
| 'huggingface_api_key': self.get_provider_api_key(model_id=model_id) | |||||
| } | |||||
| def get_provider_name(self): | |||||
| return ProviderName.HUGGINGFACEHUB |
| from typing import Optional, Union | |||||
| from core.llm.provider.anthropic_provider import AnthropicProvider | |||||
| from core.llm.provider.azure_provider import AzureProvider | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from core.llm.provider.huggingface_provider import HuggingfaceProvider | |||||
| from core.llm.provider.openai_provider import OpenAIProvider | |||||
| from models.provider import Provider | |||||
| class LLMProviderService: | |||||
| def __init__(self, tenant_id: str, provider_name: str): | |||||
| self.provider = self.init_provider(tenant_id, provider_name) | |||||
| def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider: | |||||
| if provider_name == 'openai': | |||||
| return OpenAIProvider(tenant_id) | |||||
| elif provider_name == 'azure_openai': | |||||
| return AzureProvider(tenant_id) | |||||
| elif provider_name == 'anthropic': | |||||
| return AnthropicProvider(tenant_id) | |||||
| elif provider_name == 'huggingface': | |||||
| return HuggingfaceProvider(tenant_id) | |||||
| else: | |||||
| raise Exception('provider {} not found'.format(provider_name)) | |||||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||||
| return self.provider.get_models(model_id) | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||||
| return self.provider.get_credentials(model_id) | |||||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||||
| return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom) | |||||
| def get_provider_db_record(self) -> Optional[Provider]: | |||||
| return self.provider.get_provider() | |||||
| def config_validate(self, config: Union[dict | str]): | |||||
| """ | |||||
| Validates the given config. | |||||
| :param config: | |||||
| :raises: ValidateFailedError | |||||
| """ | |||||
| return self.provider.config_validate(config) | |||||
| def get_token_type(self): | |||||
| return self.provider.get_token_type() | |||||
| def get_encrypted_token(self, config: Union[dict | str]): | |||||
| return self.provider.get_encrypted_token(config) |
| import logging | |||||
| from typing import Optional, Union | |||||
| import openai | |||||
| from openai.error import AuthenticationError, OpenAIError | |||||
| from core import hosted_llm_credentials | |||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.llm.moderation import Moderation | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from core.llm.provider.errors import ValidateFailedError | |||||
| from models.provider import ProviderName | |||||
| class OpenAIProvider(BaseProvider): | |||||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||||
| credentials = self.get_credentials(model_id) | |||||
| response = openai.Model.list(**credentials) | |||||
| return [{ | |||||
| 'id': model['id'], | |||||
| 'name': model['id'], | |||||
| } for model in response['data']] | |||||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||||
| """ | |||||
| Returns the credentials for the given tenant_id and provider_name. | |||||
| """ | |||||
| return { | |||||
| 'openai_api_key': self.get_provider_api_key(model_id=model_id) | |||||
| } | |||||
| def get_provider_name(self): | |||||
| return ProviderName.OPENAI | |||||
| def config_validate(self, config: Union[dict | str]): | |||||
| """ | |||||
| Validates the given config. | |||||
| """ | |||||
| try: | |||||
| Moderation(self.get_provider_name().value, config).moderate('test') | |||||
| except (AuthenticationError, OpenAIError) as ex: | |||||
| raise ValidateFailedError(str(ex)) | |||||
| except Exception as ex: | |||||
| logging.exception('OpenAI config validation failed') | |||||
| raise ex | |||||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||||
| if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: | |||||
| raise ProviderTokenNotInitError( | |||||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||||
| ) | |||||
| return hosted_llm_credentials.openai.api_key |
| from typing import List, Optional, Any, Dict | |||||
| from httpx import Timeout | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.chat_models import ChatAnthropic | |||||
| from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage | |||||
| from pydantic import root_validator | |||||
| from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions | |||||
| class StreamableChatAnthropic(ChatAnthropic): | |||||
| """ | |||||
| Wrapper around Anthropic's large language model. | |||||
| """ | |||||
| default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0) | |||||
| @root_validator() | |||||
| def prepare_params(cls, values: Dict) -> Dict: | |||||
| values['model_name'] = values.get('model') | |||||
| values['max_tokens'] = values.get('max_tokens_to_sample') | |||||
| return values | |||||
| @handle_anthropic_exceptions | |||||
| def generate( | |||||
| self, | |||||
| messages: List[List[BaseMessage]], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| *, | |||||
| tags: Optional[List[str]] = None, | |||||
| metadata: Optional[Dict[str, Any]] = None, | |||||
| **kwargs: Any, | |||||
| ) -> LLMResult: | |||||
| return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs) | |||||
| @classmethod | |||||
| def get_kwargs_from_model_params(cls, params: dict): | |||||
| params['model'] = params.get('model_name') | |||||
| del params['model_name'] | |||||
| params['max_tokens_to_sample'] = params.get('max_tokens') | |||||
| del params['max_tokens'] | |||||
| del params['frequency_penalty'] | |||||
| del params['presence_penalty'] | |||||
| return params | |||||
| def _convert_one_message_to_text(self, message: BaseMessage) -> str: | |||||
| if isinstance(message, ChatMessage): | |||||
| message_text = f"\n\n{message.role.capitalize()}: {message.content}" | |||||
| elif isinstance(message, HumanMessage): | |||||
| message_text = f"{self.HUMAN_PROMPT} {message.content}" | |||||
| elif isinstance(message, AIMessage): | |||||
| message_text = f"{self.AI_PROMPT} {message.content}" | |||||
| elif isinstance(message, SystemMessage): | |||||
| message_text = f"<admin>{message.content}</admin>" | |||||
| else: | |||||
| raise ValueError(f"Got unknown type {message}") | |||||
| return message_text |
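| # Editorial example of the removed message formatting, assuming the anthropic | |||||
| # package's HUMAN_PROMPT ("\n\nHuman:") and AI_PROMPT ("\n\nAssistant:"): | |||||
| #   HumanMessage("hi")     -> "\n\nHuman: hi" | |||||
| #   AIMessage("hello")     -> "\n\nAssistant: hello" | |||||
| #   SystemMessage("rules") -> "<admin>rules</admin>" | |||||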
| import decimal | |||||
| from typing import Optional | |||||
| import tiktoken | |||||
| from core.constant import llm_constant | |||||
| class TokenCalculator: | |||||
| @classmethod | |||||
| def get_num_tokens(cls, model_name: str, text: str): | |||||
| if len(text) == 0: | |||||
| return 0 | |||||
| enc = tiktoken.encoding_for_model(model_name) | |||||
| tokenized_text = enc.encode(text) | |||||
| # calculate the number of tokens in the encoded text | |||||
| return len(tokenized_text) | |||||
| @classmethod | |||||
| def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal: | |||||
| if model_name in llm_constant.models_by_mode['embedding']: | |||||
| unit_price = llm_constant.model_prices[model_name]['usage'] | |||||
| elif text_type == 'prompt': | |||||
| unit_price = llm_constant.model_prices[model_name]['prompt'] | |||||
| elif text_type == 'completion': | |||||
| unit_price = llm_constant.model_prices[model_name]['completion'] | |||||
| else: | |||||
| raise Exception('Invalid text type') | |||||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||||
| rounding=decimal.ROUND_HALF_UP) | |||||
| total_price = tokens_per_1k * unit_price | |||||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||||
| @classmethod | |||||
| def get_currency(cls, model_name: str): | |||||
| return llm_constant.model_currency |
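| # Editorial worked example of the removed pricing math, with an illustrative | |||||
| # unit price of 0.002 USD per 1K tokens (not a figure from the source): | |||||
| #   tokens_per_1k = Decimal(1530) / 1000 -> 1.530 (quantized to '0.001') | |||||
| #   total_price = 1.530 * 0.002 -> 0.0030600 (quantized to '0.0000001') | |||||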
| import openai | |||||
| from core.llm.wrappers.openai_wrapper import handle_openai_exceptions | |||||
| from models.provider import ProviderName | |||||
| from core.llm.provider.base import BaseProvider | |||||
| class Whisper: | |||||
| def __init__(self, provider: BaseProvider): | |||||
| self.provider = provider | |||||
| if self.provider.get_provider_name() == ProviderName.OPENAI: | |||||
| self.client = openai.Audio | |||||
| self.credentials = provider.get_credentials() | |||||
| @handle_openai_exceptions | |||||
| def transcribe(self, file): | |||||
| return self.client.transcribe( | |||||
| model='whisper-1', | |||||
| file=file, | |||||
| api_key=self.credentials.get('openai_api_key'), | |||||
| api_base=self.credentials.get('openai_api_base'), | |||||
| api_type=self.credentials.get('openai_api_type'), | |||||
| api_version=self.credentials.get('openai_api_version'), | |||||
| ) |
| import logging | |||||
| from functools import wraps | |||||
| import anthropic | |||||
| from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ | |||||
| LLMBadRequestError | |||||
| def handle_anthropic_exceptions(func): | |||||
| @wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| try: | |||||
| return func(*args, **kwargs) | |||||
| except anthropic.APIConnectionError as e: | |||||
| logging.exception("Failed to connect to Anthropic API.") | |||||
| raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}") | |||||
| except anthropic.RateLimitError: | |||||
| raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") | |||||
| except anthropic.AuthenticationError as e: | |||||
| raise LLMAuthorizationError(f"Anthropic: {e.message}") | |||||
| except anthropic.BadRequestError as e: | |||||
| raise LLMBadRequestError(f"Anthropic: {e.message}") | |||||
| except anthropic.APIStatusError as e: | |||||
| raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}") | |||||
| return wrapper |
| import logging | |||||
| from functools import wraps | |||||
| import openai | |||||
| from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ | |||||
| LLMBadRequestError | |||||
| def handle_openai_exceptions(func): | |||||
| @wraps(func) | |||||
| def wrapper(*args, **kwargs): | |||||
| try: | |||||
| return func(*args, **kwargs) | |||||
| except openai.error.InvalidRequestError as e: | |||||
| logging.exception("Invalid request to OpenAI API.") | |||||
| raise LLMBadRequestError(str(e)) | |||||
| except openai.error.APIConnectionError as e: | |||||
| logging.exception("Failed to connect to OpenAI API.") | |||||
| raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e)) | |||||
| except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e: | |||||
| logging.exception("OpenAI service unavailable.") | |||||
| raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e)) | |||||
| except openai.error.RateLimitError as e: | |||||
| raise LLMRateLimitError(str(e)) | |||||
| except openai.error.AuthenticationError as e: | |||||
| raise LLMAuthorizationError(str(e)) | |||||
| except openai.error.OpenAIError as e: | |||||
| raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) | |||||
| return wrapper |
| from typing import Any, List, Dict, Union | |||||
| from typing import Any, List, Dict | |||||
| from langchain.memory.chat_memory import BaseChatMemory | from langchain.memory.chat_memory import BaseChatMemory | ||||
| from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel | |||||
| from langchain.schema import get_buffer_string, BaseMessage | |||||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | |||||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.model import Conversation, Message | from models.model import Conversation, Message | ||||
| conversation: Conversation | conversation: Conversation | ||||
| human_prefix: str = "Human" | human_prefix: str = "Human" | ||||
| ai_prefix: str = "Assistant" | ai_prefix: str = "Assistant" | ||||
| llm: BaseLanguageModel | |||||
| model_instance: BaseLLM | |||||
| memory_key: str = "chat_history" | memory_key: str = "chat_history" | ||||
| max_token_limit: int = 2000 | max_token_limit: int = 2000 | ||||
| message_limit: int = 10 | message_limit: int = 10 | ||||
| messages = list(reversed(messages)) | messages = list(reversed(messages)) | ||||
| chat_messages: List[BaseMessage] = [] | |||||
| chat_messages: List[PromptMessage] = [] | |||||
| for message in messages: | for message in messages: | ||||
| chat_messages.append(HumanMessage(content=message.query)) | |||||
| chat_messages.append(AIMessage(content=message.answer)) | |||||
| chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN)) | |||||
| chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) | |||||
| if not chat_messages: | if not chat_messages: | ||||
| return chat_messages | |||||
| return [] | |||||
| # prune the chat message if it exceeds the max token limit | # prune the chat message if it exceeds the max token limit | ||||
| curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) | |||||
| curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) | |||||
| if curr_buffer_length > self.max_token_limit: | if curr_buffer_length > self.max_token_limit: | ||||
| pruned_memory = [] | pruned_memory = [] | ||||
| while curr_buffer_length > self.max_token_limit and chat_messages: | while curr_buffer_length > self.max_token_limit and chat_messages: | ||||
| pruned_memory.append(chat_messages.pop(0)) | pruned_memory.append(chat_messages.pop(0)) | ||||
| curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) | |||||
| curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) | |||||
| return chat_messages | |||||
| return to_lc_messages(chat_messages) | |||||
| @property | @property | ||||
| def memory_variables(self) -> List[str]: | def memory_variables(self) -> List[str]: |
| from typing import Optional | |||||
| from langchain.callbacks.base import Callbacks | |||||
| from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError | |||||
| from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS | |||||
| from core.model_providers.models.base import BaseProviderModel | |||||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.model_providers.models.speech2text.base import BaseSpeech2Text | |||||
| from extensions.ext_database import db | |||||
| from models.provider import TenantDefaultModel | |||||
| class ModelFactory: | |||||
| @classmethod | |||||
| def get_text_generation_model_from_model_config(cls, tenant_id: str, | |||||
| model_config: dict, | |||||
| streaming: bool = False, | |||||
| callbacks: Callbacks = None) -> Optional[BaseLLM]: | |||||
| provider_name = model_config.get("provider") | |||||
| model_name = model_config.get("name") | |||||
| completion_params = model_config.get("completion_params", {}) | |||||
| return cls.get_text_generation_model( | |||||
| tenant_id=tenant_id, | |||||
| model_provider_name=provider_name, | |||||
| model_name=model_name, | |||||
| model_kwargs=ModelKwargs( | |||||
| temperature=completion_params.get('temperature', 0), | |||||
| max_tokens=completion_params.get('max_tokens', 256), | |||||
| top_p=completion_params.get('top_p', 0), | |||||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | |||||
| presence_penalty=completion_params.get('presence_penalty', 0.1) | |||||
| ), | |||||
| streaming=streaming, | |||||
| callbacks=callbacks | |||||
| ) | |||||
| @classmethod | |||||
| def get_text_generation_model(cls, | |||||
| tenant_id: str, | |||||
| model_provider_name: Optional[str] = None, | |||||
| model_name: Optional[str] = None, | |||||
| model_kwargs: Optional[ModelKwargs] = None, | |||||
| streaming: bool = False, | |||||
| callbacks: Callbacks = None) -> Optional[BaseLLM]: | |||||
| """ | |||||
| get text generation model. | |||||
| :param tenant_id: a string representing the ID of the tenant. | |||||
| :param model_provider_name: | |||||
| :param model_name: | |||||
| :param model_kwargs: | |||||
| :param streaming: | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| is_default_model = False | |||||
| if model_provider_name is None and model_name is None: | |||||
| default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION) | |||||
| if not default_model: | |||||
| raise LLMBadRequestError(f"Default model is not available. " | |||||
| f"Please configure a Default System Reasoning Model " | |||||
| f"in the Settings -> Model Provider.") | |||||
| model_provider_name = default_model.provider_name | |||||
| model_name = default_model.model_name | |||||
| is_default_model = True | |||||
| # get model provider | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||||
| if not model_provider: | |||||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.") | |||||
| # init text generation model | |||||
| model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION) | |||||
| try: | |||||
| model_instance = model_class( | |||||
| model_provider=model_provider, | |||||
| name=model_name, | |||||
| model_kwargs=model_kwargs, | |||||
| streaming=streaming, | |||||
| callbacks=callbacks | |||||
| ) | |||||
| except LLMBadRequestError as e: | |||||
| if is_default_model: | |||||
| raise LLMBadRequestError(f"Default model {model_name} is not available. " | |||||
| f"Please check your model provider credentials.") | |||||
| else: | |||||
| raise e | |||||
| if is_default_model: | |||||
| model_instance.deduct_quota = False | |||||
| return model_instance | |||||
| @classmethod | |||||
| def get_embedding_model(cls, | |||||
| tenant_id: str, | |||||
| model_provider_name: Optional[str] = None, | |||||
| model_name: Optional[str] = None) -> Optional[BaseEmbedding]: | |||||
| """ | |||||
| get embedding model. | |||||
| :param tenant_id: a string representing the ID of the tenant. | |||||
| :param model_provider_name: | |||||
| :param model_name: | |||||
| :return: | |||||
| """ | |||||
| if model_provider_name is None and model_name is None: | |||||
| default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS) | |||||
| if not default_model: | |||||
| raise LLMBadRequestError(f"Default model is not available. " | |||||
| f"Please configure a Default Embedding Model " | |||||
| f"in the Settings -> Model Provider.") | |||||
| model_provider_name = default_model.provider_name | |||||
| model_name = default_model.model_name | |||||
| # get model provider | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||||
| if not model_provider: | |||||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.") | |||||
| # init embedding model | |||||
| model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS) | |||||
| return model_class( | |||||
| model_provider=model_provider, | |||||
| name=model_name | |||||
| ) | |||||
| @classmethod | |||||
| def get_speech2text_model(cls, | |||||
| tenant_id: str, | |||||
| model_provider_name: Optional[str] = None, | |||||
| model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]: | |||||
| """ | |||||
| get speech to text model. | |||||
| :param tenant_id: a string representing the ID of the tenant. | |||||
| :param model_provider_name: | |||||
| :param model_name: | |||||
| :return: | |||||
| """ | |||||
| if model_provider_name is None and model_name is None: | |||||
| default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT) | |||||
| if not default_model: | |||||
| raise LLMBadRequestError(f"Default model is not available. " | |||||
| f"Please configure a Default Speech-to-Text Model " | |||||
| f"in the Settings -> Model Provider.") | |||||
| model_provider_name = default_model.provider_name | |||||
| model_name = default_model.model_name | |||||
| # get model provider | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||||
| if not model_provider: | |||||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.") | |||||
| # init speech to text model | |||||
| model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT) | |||||
| return model_class( | |||||
| model_provider=model_provider, | |||||
| name=model_name | |||||
| ) | |||||
| @classmethod | |||||
| def get_moderation_model(cls, | |||||
| tenant_id: str, | |||||
| model_provider_name: str, | |||||
| model_name: str) -> Optional[BaseProviderModel]: | |||||
| """ | |||||
| get moderation model. | |||||
| :param tenant_id: a string representing the ID of the tenant. | |||||
| :param model_provider_name: | |||||
| :param model_name: | |||||
| :return: | |||||
| """ | |||||
| # get model provider | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||||
| if not model_provider: | |||||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.") | |||||
| # init moderation model | |||||
| model_class = model_provider.get_model_class(model_type=ModelType.MODERATION) | |||||
| return model_class( | |||||
| model_provider=model_provider, | |||||
| name=model_name | |||||
| ) | |||||
| @classmethod | |||||
| def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel: | |||||
| """ | |||||
| get default model of model type. | |||||
| :param tenant_id: | |||||
| :param model_type: | |||||
| :return: | |||||
| """ | |||||
| # get default model | |||||
| default_model = db.session.query(TenantDefaultModel) \ | |||||
| .filter( | |||||
| TenantDefaultModel.tenant_id == tenant_id, | |||||
| TenantDefaultModel.model_type == model_type.value | |||||
| ).first() | |||||
| if not default_model: | |||||
| model_provider_rules = ModelProviderFactory.get_provider_rules() | |||||
| for model_provider_name, model_provider_rule in model_provider_rules.items(): | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||||
| if not model_provider: | |||||
| continue | |||||
| model_list = model_provider.get_supported_model_list(model_type) | |||||
| if model_list: | |||||
| model_info = model_list[0] | |||||
| default_model = TenantDefaultModel( | |||||
| tenant_id=tenant_id, | |||||
| model_type=model_type.value, | |||||
| provider_name=model_provider_name, | |||||
| model_name=model_info['id'] | |||||
| ) | |||||
| db.session.add(default_model) | |||||
| db.session.commit() | |||||
| break | |||||
| return default_model | |||||
| @classmethod | |||||
| def update_default_model(cls, | |||||
| tenant_id: str, | |||||
| model_type: ModelType, | |||||
| provider_name: str, | |||||
| model_name: str) -> TenantDefaultModel: | |||||
| """ | |||||
| update default model of model type. | |||||
| :param tenant_id: | |||||
| :param model_type: | |||||
| :param provider_name: | |||||
| :param model_name: | |||||
| :return: | |||||
| """ | |||||
| model_provider_names = ModelProviderFactory.get_provider_names() | |||||
| if provider_name not in model_provider_names: | |||||
| raise ValueError(f'Invalid provider name: {provider_name}') | |||||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name) | |||||
| if not model_provider: | |||||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials are not initialized.") | |||||
| model_list = model_provider.get_supported_model_list(model_type) | |||||
| model_ids = [model['id'] for model in model_list] | |||||
| if model_name not in model_ids: | |||||
| raise ValueError(f'Invalid model name: {model_name}') | |||||
| # get default model | |||||
| default_model = db.session.query(TenantDefaultModel) \ | |||||
| .filter( | |||||
| TenantDefaultModel.tenant_id == tenant_id, | |||||
| TenantDefaultModel.model_type == model_type.value | |||||
| ).first() | |||||
| if default_model: | |||||
| # update default model | |||||
| default_model.provider_name = provider_name | |||||
| default_model.model_name = model_name | |||||
| db.session.commit() | |||||
| else: | |||||
| # create default model | |||||
| default_model = TenantDefaultModel( | |||||
| tenant_id=tenant_id, | |||||
| model_type=model_type.value, | |||||
| provider_name=provider_name, | |||||
| model_name=model_name, | |||||
| ) | |||||
| db.session.add(default_model) | |||||
| db.session.commit() | |||||
| return default_model |
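A hedged usage sketch of the factory above; the tenant id and parameter values are illustrative placeholders, and the factory's module path is assumed from this package layout:

from core.model_providers.model_factory import ModelFactory  # assumed path
from core.model_providers.models.entity.model_params import ModelKwargs

llm = ModelFactory.get_text_generation_model(
    tenant_id='tenant-123',  # hypothetical tenant
    model_provider_name='openai',
    model_name='gpt-3.5-turbo',
    model_kwargs=ModelKwargs(temperature=0, max_tokens=256, top_p=1,
                             presence_penalty=0, frequency_penalty=0),
    streaming=False
)
# Omitting provider and model name falls back to the tenant's default
# text generation model (and skips quota deduction, per the code above):
default_llm = ModelFactory.get_text_generation_model(tenant_id='tenant-123')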
| from typing import Type | |||||
| from sqlalchemy.exc import IntegrityError | |||||
| from core.model_providers.models.entity.model_params import ModelType | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| from core.model_providers.rules import provider_rules | |||||
| from extensions.ext_database import db | |||||
| from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType | |||||
| DEFAULT_MODELS = { | |||||
| ModelType.TEXT_GENERATION.value: { | |||||
| 'provider_name': 'openai', | |||||
| 'model_name': 'gpt-3.5-turbo', | |||||
| }, | |||||
| ModelType.EMBEDDINGS.value: { | |||||
| 'provider_name': 'openai', | |||||
| 'model_name': 'text-embedding-ada-002', | |||||
| }, | |||||
| ModelType.SPEECH_TO_TEXT.value: { | |||||
| 'provider_name': 'openai', | |||||
| 'model_name': 'whisper-1', | |||||
| } | |||||
| } | |||||
| class ModelProviderFactory: | |||||
| @classmethod | |||||
| def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]: | |||||
| if provider_name == 'openai': | |||||
| from core.model_providers.providers.openai_provider import OpenAIProvider | |||||
| return OpenAIProvider | |||||
| elif provider_name == 'anthropic': | |||||
| from core.model_providers.providers.anthropic_provider import AnthropicProvider | |||||
| return AnthropicProvider | |||||
| elif provider_name == 'minimax': | |||||
| from core.model_providers.providers.minimax_provider import MinimaxProvider | |||||
| return MinimaxProvider | |||||
| elif provider_name == 'spark': | |||||
| from core.model_providers.providers.spark_provider import SparkProvider | |||||
| return SparkProvider | |||||
| elif provider_name == 'tongyi': | |||||
| from core.model_providers.providers.tongyi_provider import TongyiProvider | |||||
| return TongyiProvider | |||||
| elif provider_name == 'wenxin': | |||||
| from core.model_providers.providers.wenxin_provider import WenxinProvider | |||||
| return WenxinProvider | |||||
| elif provider_name == 'chatglm': | |||||
| from core.model_providers.providers.chatglm_provider import ChatGLMProvider | |||||
| return ChatGLMProvider | |||||
| elif provider_name == 'azure_openai': | |||||
| from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider | |||||
| return AzureOpenAIProvider | |||||
| elif provider_name == 'replicate': | |||||
| from core.model_providers.providers.replicate_provider import ReplicateProvider | |||||
| return ReplicateProvider | |||||
| elif provider_name == 'huggingface_hub': | |||||
| from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider | |||||
| return HuggingfaceHubProvider | |||||
| else: | |||||
| raise NotImplementedError | |||||
| @classmethod | |||||
| def get_provider_names(cls): | |||||
| """ | |||||
| Returns a list of provider names. | |||||
| """ | |||||
| return list(provider_rules.keys()) | |||||
| @classmethod | |||||
| def get_provider_rules(cls): | |||||
| """ | |||||
| Returns the provider rules dict, keyed by provider name. | |||||
| :return: | |||||
| """ | |||||
| return provider_rules | |||||
| @classmethod | |||||
| def get_provider_rule(cls, provider_name: str): | |||||
| """ | |||||
| Returns provider rule. | |||||
| """ | |||||
| return provider_rules[provider_name] | |||||
| @classmethod | |||||
| def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str): | |||||
| """ | |||||
| get preferred model provider. | |||||
| :param tenant_id: a string representing the ID of the tenant. | |||||
| :param model_provider_name: | |||||
| :return: | |||||
| """ | |||||
| # get preferred provider | |||||
| preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name) | |||||
| if not preferred_provider or not preferred_provider.is_valid: | |||||
| return None | |||||
| # init model provider | |||||
| model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name) | |||||
| return model_provider_class(provider=preferred_provider) | |||||
| @classmethod | |||||
| def get_preferred_type_by_preferred_model_provider(cls, | |||||
| tenant_id: str, | |||||
| model_provider_name: str, | |||||
| preferred_model_provider: TenantPreferredModelProvider): | |||||
| """ | |||||
| get preferred provider type by preferred model provider. | |||||
| :param tenant_id: | |||||
| :param model_provider_name: | |||||
| :param preferred_model_provider: | |||||
| :return: | |||||
| """ | |||||
| if not preferred_model_provider: | |||||
| model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) | |||||
| support_provider_types = model_provider_rules['support_provider_types'] | |||||
| if ProviderType.CUSTOM.value in support_provider_types: | |||||
| custom_provider = db.session.query(Provider) \ | |||||
| .filter( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == model_provider_name, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||||
| Provider.is_valid == True | |||||
| ).first() | |||||
| if custom_provider: | |||||
| return ProviderType.CUSTOM.value | |||||
| model_provider = cls.get_model_provider_class(model_provider_name) | |||||
| if ProviderType.SYSTEM.value in support_provider_types \ | |||||
| and model_provider.is_provider_type_system_supported(): | |||||
| return ProviderType.SYSTEM.value | |||||
| elif ProviderType.CUSTOM.value in support_provider_types: | |||||
| return ProviderType.CUSTOM.value | |||||
| else: | |||||
| return preferred_model_provider.preferred_provider_type | |||||
| @classmethod | |||||
| def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str): | |||||
| """ | |||||
| get preferred provider of tenant. | |||||
| :param tenant_id: | |||||
| :param model_provider_name: | |||||
| :return: | |||||
| """ | |||||
| # get preferred provider type | |||||
| preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name) | |||||
| # get providers by preferred provider type | |||||
| providers = db.session.query(Provider) \ | |||||
| .filter( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == model_provider_name, | |||||
| Provider.provider_type == preferred_provider_type | |||||
| ).all() | |||||
| no_system_provider = False | |||||
| if preferred_provider_type == ProviderType.SYSTEM.value: | |||||
| quota_type_to_provider_dict = {} | |||||
| for provider in providers: | |||||
| quota_type_to_provider_dict[provider.quota_type] = provider | |||||
| model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) | |||||
| for quota_type_enum in ProviderQuotaType: | |||||
| quota_type = quota_type_enum.value | |||||
| if quota_type in model_provider_rules['system_config']['supported_quota_types'] \ | |||||
| and quota_type in quota_type_to_provider_dict.keys(): | |||||
| provider = quota_type_to_provider_dict[quota_type] | |||||
| if provider.is_valid and provider.quota_limit > provider.quota_used: | |||||
| return provider | |||||
| no_system_provider = True | |||||
| if no_system_provider: | |||||
| providers = db.session.query(Provider) \ | |||||
| .filter( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == model_provider_name, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value | |||||
| ).all() | |||||
| if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider: | |||||
| if providers: | |||||
| return providers[0] | |||||
| else: | |||||
| try: | |||||
| provider = Provider( | |||||
| tenant_id=tenant_id, | |||||
| provider_name=model_provider_name, | |||||
| provider_type=ProviderType.CUSTOM.value, | |||||
| is_valid=False | |||||
| ) | |||||
| db.session.add(provider) | |||||
| db.session.commit() | |||||
| except IntegrityError: | |||||
| db.session.rollback() | |||||
| provider = db.session.query(Provider) \ | |||||
| .filter( | |||||
| Provider.tenant_id == tenant_id, | |||||
| Provider.provider_name == model_provider_name, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value | |||||
| ).first() | |||||
| return provider | |||||
| return None | |||||
| @classmethod | |||||
| def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str): | |||||
| """ | |||||
| get preferred provider type of tenant. | |||||
| :param tenant_id: | |||||
| :param model_provider_name: | |||||
| :return: | |||||
| """ | |||||
| preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ | |||||
| .filter( | |||||
| TenantPreferredModelProvider.tenant_id == tenant_id, | |||||
| TenantPreferredModelProvider.provider_name == model_provider_name | |||||
| ).first() | |||||
| return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider) |
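The system-quota fallback in _get_preferred_provider can be restated as a small standalone sketch; provider rows are modelled as plain dicts here, and in the real code the priority order comes from iterating ProviderQuotaType against the rule's supported_quota_types:

def pick_system_provider(providers: list, quota_type_priority: list):
    # First valid system provider, in quota-type priority order, with quota left.
    by_quota = {p['quota_type']: p for p in providers}
    for quota_type in quota_type_priority:
        p = by_quota.get(quota_type)
        if p and p['is_valid'] and p['quota_limit'] > p['quota_used']:
            return p
    return None  # caller then falls back to (or lazily creates) a CUSTOM provider

providers = [
    {'quota_type': 'trial', 'is_valid': True, 'quota_limit': 200, 'quota_used': 200},
    {'quota_type': 'paid', 'is_valid': True, 'quota_limit': 1000, 'quota_used': 10},
]
assert pick_system_provider(providers, ['paid', 'trial'])['quota_type'] == 'paid'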
| from abc import ABC | |||||
| from typing import Any | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| class BaseProviderModel(ABC): | |||||
| _client: Any | |||||
| _model_provider: BaseModelProvider | |||||
| def __init__(self, model_provider: BaseModelProvider, client: Any): | |||||
| self._model_provider = model_provider | |||||
| self._client = client | |||||
| @property | |||||
| def client(self): | |||||
| return self._client | |||||
| @property | |||||
| def model_provider(self): | |||||
| return self._model_provider | |||||
| import decimal | |||||
| import logging | |||||
| import openai | |||||
| import tiktoken | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \ | |||||
| LLMAPIUnavailableError, LLMAPIConnectionError | |||||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| AZURE_OPENAI_API_VERSION = '2023-07-01-preview' | |||||
| class AzureOpenAIEmbedding(BaseEmbedding): | |||||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||||
| self.credentials = model_provider.get_model_credentials( | |||||
| model_name=name, | |||||
| model_type=self.type | |||||
| ) | |||||
| client = OpenAIEmbeddings( | |||||
| deployment=name, | |||||
| openai_api_type='azure', | |||||
| openai_api_version=AZURE_OPENAI_API_VERSION, | |||||
| chunk_size=16, | |||||
| max_retries=1, | |||||
| **self.credentials | |||||
| ) | |||||
| super().__init__(model_provider, client, name) | |||||
| def get_num_tokens(self, text: str) -> int: | |||||
| """ | |||||
| get num tokens of text. | |||||
| :param text: | |||||
| :return: | |||||
| """ | |||||
| if len(text) == 0: | |||||
| return 0 | |||||
| enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name')) | |||||
| tokenized_text = enc.encode(text) | |||||
| # calculate the number of tokens in the encoded text | |||||
| return len(tokenized_text) | |||||
| def get_token_price(self, tokens: int): | |||||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||||
| rounding=decimal.ROUND_HALF_UP) | |||||
| total_price = tokens_per_1k * decimal.Decimal('0.0001') | |||||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||||
| def get_currency(self): | |||||
| return 'USD' | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, openai.error.InvalidRequestError): | |||||
| logging.warning("Invalid request to Azure OpenAI API.") | |||||
| return LLMBadRequestError(str(ex)) | |||||
| elif isinstance(ex, openai.error.APIConnectionError): | |||||
| logging.warning("Failed to connect to Azure OpenAI API.") | |||||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||||
| logging.warning("Azure OpenAI service unavailable.") | |||||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| elif isinstance(ex, openai.error.RateLimitError): | |||||
| return LLMRateLimitError('Azure ' + str(ex)) | |||||
| elif isinstance(ex, openai.error.AuthenticationError): | |||||
| return LLMAuthorizationError('Azure ' + str(ex)) | |||||
| elif isinstance(ex, openai.error.OpenAIError): | |||||
| return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex)) | |||||
| else: | |||||
| return ex |
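For concreteness, a worked example of the embedding price maths above at the USD 0.0001-per-1K-token rate used in get_token_price:

import decimal

def embedding_price(tokens: int) -> decimal.Decimal:
    tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
        decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)
    total = tokens_per_1k * decimal.Decimal('0.0001')
    return total.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

# 15,000 tokens -> 15.000 * 0.0001 = USD 0.0015
assert embedding_price(15000) == decimal.Decimal('0.0015')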
| from abc import abstractmethod | |||||
| from typing import Any | |||||
| import tiktoken | |||||
| from langchain.schema.language_model import _get_token_ids_default_method | |||||
| from core.model_providers.models.base import BaseProviderModel | |||||
| from core.model_providers.models.entity.model_params import ModelType | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| class BaseEmbedding(BaseProviderModel): | |||||
| name: str | |||||
| type: ModelType = ModelType.EMBEDDINGS | |||||
| def __init__(self, model_provider: BaseModelProvider, client: Any, name: str): | |||||
| super().__init__(model_provider, client) | |||||
| self.name = name | |||||
| def get_num_tokens(self, text: str) -> int: | |||||
| """ | |||||
| get num tokens of text. | |||||
| :param text: | |||||
| :return: | |||||
| """ | |||||
| if len(text) == 0: | |||||
| return 0 | |||||
| return len(_get_token_ids_default_method(text)) | |||||
| def get_token_price(self, tokens: int): | |||||
| return 0 | |||||
| def get_currency(self): | |||||
| return 'USD' | |||||
| @abstractmethod | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| raise NotImplementedError |
| import decimal | |||||
| import logging | |||||
| from langchain.embeddings import MiniMaxEmbeddings | |||||
| from core.model_providers.error import LLMBadRequestError | |||||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| class MinimaxEmbedding(BaseEmbedding): | |||||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||||
| credentials = model_provider.get_model_credentials( | |||||
| model_name=name, | |||||
| model_type=self.type | |||||
| ) | |||||
| client = MiniMaxEmbeddings( | |||||
| model=name, | |||||
| **credentials | |||||
| ) | |||||
| super().__init__(model_provider, client, name) | |||||
| def get_token_price(self, tokens: int): | |||||
| return decimal.Decimal('0') | |||||
| def get_currency(self): | |||||
| return 'RMB' | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, ValueError): | |||||
| return LLMBadRequestError(f"Minimax: {str(ex)}") | |||||
| else: | |||||
| return ex |
| import decimal | |||||
| import logging | |||||
| import openai | |||||
| import tiktoken | |||||
| from langchain.embeddings import OpenAIEmbeddings | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||||
| LLMRateLimitError, LLMAuthorizationError | |||||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| class OpenAIEmbedding(BaseEmbedding): | |||||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||||
| credentials = model_provider.get_model_credentials( | |||||
| model_name=name, | |||||
| model_type=self.type | |||||
| ) | |||||
| client = OpenAIEmbeddings( | |||||
| max_retries=1, | |||||
| **credentials | |||||
| ) | |||||
| super().__init__(model_provider, client, name) | |||||
| def get_num_tokens(self, text: str) -> int: | |||||
| """ | |||||
| get num tokens of text. | |||||
| :param text: | |||||
| :return: | |||||
| """ | |||||
| if len(text) == 0: | |||||
| return 0 | |||||
| enc = tiktoken.encoding_for_model(self.name) | |||||
| tokenized_text = enc.encode(text) | |||||
| # calculate the number of tokens in the encoded text | |||||
| return len(tokenized_text) | |||||
| def get_token_price(self, tokens: int): | |||||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||||
| rounding=decimal.ROUND_HALF_UP) | |||||
| total_price = tokens_per_1k * decimal.Decimal('0.0001') | |||||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||||
| def get_currency(self): | |||||
| return 'USD' | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, openai.error.InvalidRequestError): | |||||
| logging.warning("Invalid request to OpenAI API.") | |||||
| return LLMBadRequestError(str(ex)) | |||||
| elif isinstance(ex, openai.error.APIConnectionError): | |||||
| logging.warning("Failed to connect to OpenAI API.") | |||||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||||
| logging.warning("OpenAI service unavailable.") | |||||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| elif isinstance(ex, openai.error.RateLimitError): | |||||
| return LLMRateLimitError(str(ex)) | |||||
| elif isinstance(ex, openai.error.AuthenticationError): | |||||
| return LLMAuthorizationError(str(ex)) | |||||
| elif isinstance(ex, openai.error.OpenAIError): | |||||
| return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| else: | |||||
| return ex |
| import decimal | |||||
| from replicate.exceptions import ModelError, ReplicateError | |||||
| from core.model_providers.error import LLMBadRequestError | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings | |||||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||||
| class ReplicateEmbedding(BaseEmbedding): | |||||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||||
| credentials = model_provider.get_model_credentials( | |||||
| model_name=name, | |||||
| model_type=self.type | |||||
| ) | |||||
| client = ReplicateEmbeddings( | |||||
| model=name + ':' + credentials.get('model_version'), | |||||
| replicate_api_token=credentials.get('replicate_api_token') | |||||
| ) | |||||
| super().__init__(model_provider, client, name) | |||||
| def get_token_price(self, tokens: int): | |||||
| # Replicate bills by prediction time, not per token | |||||
| return decimal.Decimal('0') | |||||
| def get_currency(self): | |||||
| return 'USD' | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, (ModelError, ReplicateError)): | |||||
| return LLMBadRequestError(f"Replicate: {str(ex)}") | |||||
| else: | |||||
| return ex |
| import enum | |||||
| from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage | |||||
| from pydantic import BaseModel | |||||
| class LLMRunResult(BaseModel): | |||||
| content: str | |||||
| prompt_tokens: int | |||||
| completion_tokens: int | |||||
| class MessageType(enum.Enum): | |||||
| HUMAN = 'human' | |||||
| ASSISTANT = 'assistant' | |||||
| SYSTEM = 'system' | |||||
| class PromptMessage(BaseModel): | |||||
| type: MessageType = MessageType.HUMAN | |||||
| content: str = '' | |||||
| def to_lc_messages(messages: list[PromptMessage]): | |||||
| lc_messages = [] | |||||
| for message in messages: | |||||
| if message.type == MessageType.HUMAN: | |||||
| lc_messages.append(HumanMessage(content=message.content)) | |||||
| elif message.type == MessageType.ASSISTANT: | |||||
| lc_messages.append(AIMessage(content=message.content)) | |||||
| elif message.type == MessageType.SYSTEM: | |||||
| lc_messages.append(SystemMessage(content=message.content)) | |||||
| return lc_messages | |||||
| def to_prompt_messages(messages: list[BaseMessage]): | |||||
| prompt_messages = [] | |||||
| for message in messages: | |||||
| if isinstance(message, HumanMessage): | |||||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) | |||||
| elif isinstance(message, AIMessage): | |||||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT)) | |||||
| elif isinstance(message, SystemMessage): | |||||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) | |||||
| return prompt_messages | |||||
| def str_to_prompt_messages(texts: list[str]): | |||||
| prompt_messages = [] | |||||
| for text in texts: | |||||
| prompt_messages.append(PromptMessage(content=text)) | |||||
| return prompt_messages |
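A quick round-trip check of the two converters above; it is runnable inside this module, since it reuses PromptMessage, MessageType and both helpers defined here:

msgs = [
    PromptMessage(content='You are terse.', type=MessageType.SYSTEM),
    PromptMessage(content='Hi', type=MessageType.HUMAN),
    PromptMessage(content='Hello!', type=MessageType.ASSISTANT),
]
# Each type maps onto the matching langchain class and back unchanged.
assert to_prompt_messages(to_lc_messages(msgs)) == msgs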
| import enum | |||||
| from typing import Optional, TypeVar, Generic | |||||
| from langchain.load.serializable import Serializable | |||||
| from pydantic import BaseModel | |||||
| class ModelMode(enum.Enum): | |||||
| COMPLETION = 'completion' | |||||
| CHAT = 'chat' | |||||
| class ModelType(enum.Enum): | |||||
| TEXT_GENERATION = 'text-generation' | |||||
| EMBEDDINGS = 'embeddings' | |||||
| SPEECH_TO_TEXT = 'speech2text' | |||||
| IMAGE = 'image' | |||||
| VIDEO = 'video' | |||||
| MODERATION = 'moderation' | |||||
| @staticmethod | |||||
| def value_of(value): | |||||
| for member in ModelType: | |||||
| if member.value == value: | |||||
| return member | |||||
| raise ValueError(f"No matching enum found for value '{value}'") | |||||
| class ModelKwargs(BaseModel): | |||||
| max_tokens: Optional[int] | |||||
| temperature: Optional[float] | |||||
| top_p: Optional[float] | |||||
| presence_penalty: Optional[float] | |||||
| frequency_penalty: Optional[float] | |||||
| class KwargRuleType(enum.Enum): | |||||
| STRING = 'string' | |||||
| INTEGER = 'integer' | |||||
| FLOAT = 'float' | |||||
| T = TypeVar('T') | |||||
| class KwargRule(Generic[T], BaseModel): | |||||
| enabled: bool = True | |||||
| min: Optional[T] = None | |||||
| max: Optional[T] = None | |||||
| default: Optional[T] = None | |||||
| alias: Optional[str] = None | |||||
| class ModelKwargsRules(BaseModel): | |||||
| max_tokens: KwargRule = KwargRule[int](enabled=False) | |||||
| temperature: KwargRule = KwargRule[float](enabled=False) | |||||
| top_p: KwargRule = KwargRule[float](enabled=False) | |||||
| presence_penalty: KwargRule = KwargRule[float](enabled=False) | |||||
| frequency_penalty: KwargRule = KwargRule[float](enabled=False) |
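An illustrative rule set for a hypothetical provider, showing how KwargRule bounds, defaults and aliases are declared; every value below is invented for the example:

rules = ModelKwargsRules(
    max_tokens=KwargRule[int](enabled=True, min=10, max=4096, default=256),
    temperature=KwargRule[float](enabled=True, min=0.0, max=2.0, default=1.0),
    top_p=KwargRule[float](enabled=True, min=0.0, max=1.0, default=1.0),
    # this provider is assumed to name the parameter 'repetition_penalty'
    frequency_penalty=KwargRule[float](enabled=True, min=-2.0, max=2.0,
                                       default=0.0, alias='repetition_penalty'),
)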
| from enum import Enum | |||||
| class ProviderQuotaUnit(Enum): | |||||
| TIMES = 'times' | |||||
| TOKENS = 'tokens' | |||||
| class ModelFeature(Enum): | |||||
| AGENT_THOUGHT = 'agent_thought' |
| import decimal | |||||
| import logging | |||||
| from functools import wraps | |||||
| from typing import List, Optional, Any | |||||
| import anthropic | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.chat_models import ChatAnthropic | |||||
| from langchain.schema import LLMResult | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||||
| LLMRateLimitError, LLMAuthorizationError | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||||
| class AnthropicModel(BaseLLM): | |||||
| model_mode: ModelMode = ModelMode.CHAT | |||||
| def _init_client(self) -> Any: | |||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||||
| return ChatAnthropic( | |||||
| model=self.name, | |||||
| streaming=self.streaming, | |||||
| callbacks=self.callbacks, | |||||
| default_request_timeout=60, | |||||
| **self.credentials, | |||||
| **provider_model_kwargs | |||||
| ) | |||||
| def _run(self, messages: List[PromptMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs) -> LLMResult: | |||||
| """ | |||||
| run predict by prompt messages and stop words. | |||||
| :param messages: | |||||
| :param stop: | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| prompts = self._get_prompt_from_messages(messages) | |||||
| return self._client.generate([prompts], stop, callbacks) | |||||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||||
| """ | |||||
| get num tokens of prompt messages. | |||||
| :param messages: | |||||
| :return: | |||||
| """ | |||||
| prompts = self._get_prompt_from_messages(messages) | |||||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||||
| model_unit_prices = { | |||||
| 'claude-instant-1': { | |||||
| 'prompt': decimal.Decimal('1.63'), | |||||
| 'completion': decimal.Decimal('5.51'), | |||||
| }, | |||||
| 'claude-2': { | |||||
| 'prompt': decimal.Decimal('11.02'), | |||||
| 'completion': decimal.Decimal('32.68'), | |||||
| }, | |||||
| } | |||||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||||
| unit_price = model_unit_prices[self.name]['prompt'] | |||||
| else: | |||||
| unit_price = model_unit_prices[self.name]['completion'] | |||||
| tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'), | |||||
| rounding=decimal.ROUND_HALF_UP) | |||||
| total_price = tokens_per_1m * unit_price | |||||
| return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP) | |||||
| def get_currency(self): | |||||
| return 'USD' | |||||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||||
| for k, v in provider_model_kwargs.items(): | |||||
| if hasattr(self.client, k): | |||||
| setattr(self.client, k, v) | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, anthropic.APIConnectionError): | |||||
| logging.warning("Failed to connect to Anthropic API.") | |||||
| return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}") | |||||
| elif isinstance(ex, anthropic.RateLimitError): | |||||
| return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") | |||||
| elif isinstance(ex, anthropic.AuthenticationError): | |||||
| return LLMAuthorizationError(f"Anthropic: {ex.message}") | |||||
| elif isinstance(ex, anthropic.BadRequestError): | |||||
| return LLMBadRequestError(f"Anthropic: {ex.message}") | |||||
| elif isinstance(ex, anthropic.APIStatusError): | |||||
| return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}") | |||||
| else: | |||||
| return ex | |||||
| @classmethod | |||||
| def support_streaming(cls): | |||||
| return True | |||||
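The unit prices above are USD per million tokens; a worked example for 2,500 claude-2 prompt tokens, mirroring the quantization in get_token_price:

import decimal

tokens_per_1m = (decimal.Decimal(2500) / 1000000).quantize(
    decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)  # 0.002500
price = (tokens_per_1m * decimal.Decimal('11.02')).quantize(
    decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP)
assert price == decimal.Decimal('0.02755')  # about 2.8 cents of prompt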
| import decimal | |||||
| import logging | |||||
| from functools import wraps | |||||
| from typing import List, Optional, Any | |||||
| import openai | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.schema import LLMResult | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI | |||||
| from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI | |||||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||||
| LLMRateLimitError, LLMAuthorizationError | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||||
| AZURE_OPENAI_API_VERSION = '2023-07-01-preview' | |||||
| class AzureOpenAIModel(BaseLLM): | |||||
| def __init__(self, model_provider: BaseModelProvider, | |||||
| name: str, | |||||
| model_kwargs: ModelKwargs, | |||||
| streaming: bool = False, | |||||
| callbacks: Callbacks = None): | |||||
| if name == 'text-davinci-003': | |||||
| self.model_mode = ModelMode.COMPLETION | |||||
| else: | |||||
| self.model_mode = ModelMode.CHAT | |||||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||||
| def _init_client(self) -> Any: | |||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||||
| if self.name == 'text-davinci-003': | |||||
| client = EnhanceAzureOpenAI( | |||||
| deployment_name=self.name, | |||||
| streaming=self.streaming, | |||||
| request_timeout=60, | |||||
| openai_api_type='azure', | |||||
| openai_api_version=AZURE_OPENAI_API_VERSION, | |||||
| openai_api_key=self.credentials.get('openai_api_key'), | |||||
| openai_api_base=self.credentials.get('openai_api_base'), | |||||
| callbacks=self.callbacks, | |||||
| **provider_model_kwargs | |||||
| ) | |||||
| else: | |||||
| extra_model_kwargs = { | |||||
| 'top_p': provider_model_kwargs.get('top_p'), | |||||
| 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), | |||||
| 'presence_penalty': provider_model_kwargs.get('presence_penalty'), | |||||
| } | |||||
| client = EnhanceAzureChatOpenAI( | |||||
| deployment_name=self.name, | |||||
| temperature=provider_model_kwargs.get('temperature'), | |||||
| max_tokens=provider_model_kwargs.get('max_tokens'), | |||||
| model_kwargs=extra_model_kwargs, | |||||
| streaming=self.streaming, | |||||
| request_timeout=60, | |||||
| openai_api_type='azure', | |||||
| openai_api_version=AZURE_OPENAI_API_VERSION, | |||||
| openai_api_key=self.credentials.get('openai_api_key'), | |||||
| openai_api_base=self.credentials.get('openai_api_base'), | |||||
| callbacks=self.callbacks, | |||||
| ) | |||||
| return client | |||||
| def _run(self, messages: List[PromptMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs) -> LLMResult: | |||||
| """ | |||||
| run predict by prompt messages and stop words. | |||||
| :param messages: | |||||
| :param stop: | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| prompts = self._get_prompt_from_messages(messages) | |||||
| return self._client.generate([prompts], stop, callbacks) | |||||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||||
| """ | |||||
| get num tokens of prompt messages. | |||||
| :param messages: | |||||
| :return: | |||||
| """ | |||||
| prompts = self._get_prompt_from_messages(messages) | |||||
| if isinstance(prompts, str): | |||||
| return self._client.get_num_tokens(prompts) | |||||
| else: | |||||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||||
| model_unit_prices = { | |||||
| 'gpt-4': { | |||||
| 'prompt': decimal.Decimal('0.03'), | |||||
| 'completion': decimal.Decimal('0.06'), | |||||
| }, | |||||
| 'gpt-4-32k': { | |||||
| 'prompt': decimal.Decimal('0.06'), | |||||
| 'completion': decimal.Decimal('0.12') | |||||
| }, | |||||
| 'gpt-35-turbo': { | |||||
| 'prompt': decimal.Decimal('0.0015'), | |||||
| 'completion': decimal.Decimal('0.002') | |||||
| }, | |||||
| 'gpt-35-turbo-16k': { | |||||
| 'prompt': decimal.Decimal('0.003'), | |||||
| 'completion': decimal.Decimal('0.004') | |||||
| }, | |||||
| 'text-davinci-003': { | |||||
| 'prompt': decimal.Decimal('0.02'), | |||||
| 'completion': decimal.Decimal('0.02') | |||||
| }, | |||||
| } | |||||
| base_model_name = self.credentials.get("base_model_name") | |||||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||||
| unit_price = model_unit_prices[base_model_name]['prompt'] | |||||
| else: | |||||
| unit_price = model_unit_prices[base_model_name]['completion'] | |||||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||||
| rounding=decimal.ROUND_HALF_UP) | |||||
| total_price = tokens_per_1k * unit_price | |||||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||||
| def get_currency(self): | |||||
| return 'USD' | |||||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||||
| if self.name == 'text-davinci-003': | |||||
| for k, v in provider_model_kwargs.items(): | |||||
| if hasattr(self.client, k): | |||||
| setattr(self.client, k, v) | |||||
| else: | |||||
| extra_model_kwargs = { | |||||
| 'top_p': provider_model_kwargs.get('top_p'), | |||||
| 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), | |||||
| 'presence_penalty': provider_model_kwargs.get('presence_penalty'), | |||||
| } | |||||
| self.client.temperature = provider_model_kwargs.get('temperature') | |||||
| self.client.max_tokens = provider_model_kwargs.get('max_tokens') | |||||
| self.client.model_kwargs = extra_model_kwargs | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, openai.error.InvalidRequestError): | |||||
| logging.warning("Invalid request to Azure OpenAI API.") | |||||
| return LLMBadRequestError(str(ex)) | |||||
| elif isinstance(ex, openai.error.APIConnectionError): | |||||
| logging.warning("Failed to connect to Azure OpenAI API.") | |||||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||||
| logging.warning("Azure OpenAI service unavailable.") | |||||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||||
| elif isinstance(ex, openai.error.RateLimitError): | |||||
| return LLMRateLimitError('Azure ' + str(ex)) | |||||
| elif isinstance(ex, openai.error.AuthenticationError): | |||||
| return LLMAuthorizationError('Azure ' + str(ex)) | |||||
| elif isinstance(ex, openai.error.OpenAIError): | |||||
| return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex)) | |||||
| else: | |||||
| return ex | |||||
| @classmethod | |||||
| def support_streaming(cls): | |||||
| return True |
| from abc import abstractmethod | |||||
| from typing import List, Optional, Any, Union | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration | |||||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler | |||||
| from core.model_providers.models.base import BaseProviderModel | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult | |||||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules | |||||
| from core.model_providers.providers.base import BaseModelProvider | |||||
| from core.third_party.langchain.llms.fake import FakeLLM | |||||
| class BaseLLM(BaseProviderModel): | |||||
| model_mode: ModelMode = ModelMode.COMPLETION | |||||
| name: str | |||||
| model_kwargs: ModelKwargs | |||||
| credentials: dict | |||||
| streaming: bool = False | |||||
| type: ModelType = ModelType.TEXT_GENERATION | |||||
| deduct_quota: bool = True | |||||
| def __init__(self, model_provider: BaseModelProvider, | |||||
| name: str, | |||||
| model_kwargs: ModelKwargs, | |||||
| streaming: bool = False, | |||||
| callbacks: Callbacks = None): | |||||
| self.name = name | |||||
| self.model_rules = model_provider.get_model_parameter_rules(name, self.type) | |||||
| self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs( | |||||
| max_tokens=None, | |||||
| temperature=None, | |||||
| top_p=None, | |||||
| presence_penalty=None, | |||||
| frequency_penalty=None | |||||
| ) | |||||
| self.credentials = model_provider.get_model_credentials( | |||||
| model_name=name, | |||||
| model_type=self.type | |||||
| ) | |||||
| self.streaming = streaming | |||||
| if streaming: | |||||
| default_callback = DifyStreamingStdOutCallbackHandler() | |||||
| else: | |||||
| default_callback = DifyStdOutCallbackHandler() | |||||
| if not callbacks: | |||||
| callbacks = [default_callback] | |||||
| else: | |||||
| callbacks.append(default_callback) | |||||
| self.callbacks = callbacks | |||||
| client = self._init_client() | |||||
| super().__init__(model_provider, client) | |||||
| @abstractmethod | |||||
| def _init_client(self) -> Any: | |||||
| raise NotImplementedError | |||||
| def run(self, messages: List[PromptMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs) -> LLMRunResult: | |||||
| """ | |||||
| run predict by prompt messages and stop words. | |||||
| :param messages: | |||||
| :param stop: | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| if self.deduct_quota: | |||||
| self.model_provider.check_quota_over_limit() | |||||
| if not callbacks: | |||||
| callbacks = self.callbacks | |||||
| else: | |||||
| callbacks.extend(self.callbacks) | |||||
| if 'fake_response' in kwargs and kwargs['fake_response']: | |||||
| prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) | |||||
| fake_llm = FakeLLM( | |||||
| response=kwargs['fake_response'], | |||||
| num_token_func=self.get_num_tokens, | |||||
| streaming=self.streaming, | |||||
| callbacks=callbacks | |||||
| ) | |||||
| result = fake_llm.generate([prompts]) | |||||
| else: | |||||
| try: | |||||
| result = self._run( | |||||
| messages=messages, | |||||
| stop=stop, | |||||
| callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None, | |||||
| **kwargs | |||||
| ) | |||||
| except Exception as ex: | |||||
| raise self.handle_exceptions(ex) | |||||
| if isinstance(result.generations[0][0], ChatGeneration): | |||||
| completion_content = result.generations[0][0].message.content | |||||
| else: | |||||
| completion_content = result.generations[0][0].text | |||||
| if self.streaming and not self.support_streaming(): | |||||
| # use FakeLLM to simulate streaming when the current model does not support streaming but streaming is requested | |||||
| prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) | |||||
| fake_llm = FakeLLM( | |||||
| response=completion_content, | |||||
| num_token_func=self.get_num_tokens, | |||||
| streaming=self.streaming, | |||||
| callbacks=callbacks | |||||
| ) | |||||
| fake_llm.generate([prompts]) | |||||
| if result.llm_output and result.llm_output.get('token_usage'): | |||||
| prompt_tokens = result.llm_output['token_usage']['prompt_tokens'] | |||||
| completion_tokens = result.llm_output['token_usage']['completion_tokens'] | |||||
| total_tokens = result.llm_output['token_usage']['total_tokens'] | |||||
| else: | |||||
| prompt_tokens = self.get_num_tokens(messages) | |||||
| completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]) | |||||
| total_tokens = prompt_tokens + completion_tokens | |||||
| if self.deduct_quota: | |||||
| self.model_provider.deduct_quota(total_tokens) | |||||
| return LLMRunResult( | |||||
| content=completion_content, | |||||
| prompt_tokens=prompt_tokens, | |||||
| completion_tokens=completion_tokens | |||||
| ) | |||||
| @abstractmethod | |||||
| def _run(self, messages: List[PromptMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs) -> LLMResult: | |||||
| """ | |||||
| run predict by prompt messages and stop words. | |||||
| :param messages: | |||||
| :param stop: | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||||
| """ | |||||
| get num tokens of prompt messages. | |||||
| :param messages: | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||||
| """ | |||||
| get token price. | |||||
| :param tokens: | |||||
| :param message_type: | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def get_currency(self): | |||||
| """ | |||||
| get token currency. | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def get_model_kwargs(self): | |||||
| return self.model_kwargs | |||||
| def set_model_kwargs(self, model_kwargs: ModelKwargs): | |||||
| self.model_kwargs = model_kwargs | |||||
| self._set_model_kwargs(model_kwargs) | |||||
| @abstractmethod | |||||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||||
| raise NotImplementedError | |||||
| @abstractmethod | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| """ | |||||
| Handle llm run exceptions. | |||||
| :param ex: | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def add_callbacks(self, callbacks: Callbacks): | |||||
| """ | |||||
| Add callbacks to client. | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| if not self.client.callbacks: | |||||
| self.client.callbacks = callbacks | |||||
| else: | |||||
| self.client.callbacks.extend(callbacks) | |||||
| @classmethod | |||||
| def support_streaming(cls): | |||||
| return False | |||||
| def _get_prompt_from_messages(self, messages: List[PromptMessage], | |||||
| model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]: | |||||
| if len(messages) == 0: | |||||
| raise ValueError("prompt must not be empty.") | |||||
| if not model_mode: | |||||
| model_mode = self.model_mode | |||||
| if model_mode == ModelMode.COMPLETION: | |||||
| return messages[0].content | |||||
| else: | |||||
| chat_messages = [] | |||||
| for message in messages: | |||||
| if message.type == MessageType.HUMAN: | |||||
| chat_messages.append(HumanMessage(content=message.content)) | |||||
| elif message.type == MessageType.ASSISTANT: | |||||
| chat_messages.append(AIMessage(content=message.content)) | |||||
| elif message.type == MessageType.SYSTEM: | |||||
| chat_messages.append(SystemMessage(content=message.content)) | |||||
| return chat_messages | |||||
| def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict: | |||||
| """ | |||||
| convert model kwargs to provider model kwargs. | |||||
| :param model_rules: | |||||
| :param model_kwargs: | |||||
| :return: | |||||
| """ | |||||
| model_kwargs_input = {} | |||||
| for key, value in model_kwargs.dict().items(): | |||||
| rule = getattr(model_rules, key) | |||||
| if not rule.enabled: | |||||
| continue | |||||
| if rule.alias: | |||||
| key = rule.alias | |||||
| if rule.default is not None and value is None: | |||||
| value = rule.default | |||||
| # skip clamping when no value or default was provided for the rule | |||||
| if value is not None and rule.min is not None: | |||||
| value = max(value, rule.min) | |||||
| if value is not None and rule.max is not None: | |||||
| value = min(value, rule.max) | |||||
| model_kwargs_input[key] = value | |||||
| return model_kwargs_input |
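A free-function restatement of _to_model_kwargs_input above, handy for testing the clamping in isolation; a minimal sketch plus a small check:

from core.model_providers.models.entity.model_params import (
    KwargRule, ModelKwargs, ModelKwargsRules)

def clamp_kwargs(rules: ModelKwargsRules, kwargs: ModelKwargs) -> dict:
    out = {}
    for key, value in kwargs.dict().items():
        rule = getattr(rules, key)
        if not rule.enabled:
            continue
        if value is None:
            value = rule.default  # may still be None if the rule has no default
        if value is not None:
            if rule.min is not None:
                value = max(value, rule.min)
            if rule.max is not None:
                value = min(value, rule.max)
        out[rule.alias or key] = value
    return out

rules = ModelKwargsRules(
    temperature=KwargRule[float](enabled=True, min=0.0, max=1.0, default=0.7),
    max_tokens=KwargRule[int](enabled=True, min=1, max=4096, default=256))
kwargs = ModelKwargs(temperature=1.8, max_tokens=None, top_p=None,
                     presence_penalty=None, frequency_penalty=None)
assert clamp_kwargs(rules, kwargs) == {'temperature': 1.0, 'max_tokens': 256}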
| import decimal | |||||
| from typing import List, Optional, Any | |||||
| from langchain.callbacks.manager import Callbacks | |||||
| from langchain.llms import ChatGLM | |||||
| from langchain.schema import LLMResult | |||||
| from core.model_providers.error import LLMBadRequestError | |||||
| from core.model_providers.models.llm.base import BaseLLM | |||||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||||
| class ChatGLMModel(BaseLLM): | |||||
| model_mode: ModelMode = ModelMode.COMPLETION | |||||
| def _init_client(self) -> Any: | |||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||||
| return ChatGLM( | |||||
| callbacks=self.callbacks, | |||||
| endpoint_url=self.credentials.get('api_base'), | |||||
| **provider_model_kwargs | |||||
| ) | |||||
| def _run(self, messages: List[PromptMessage], | |||||
| stop: Optional[List[str]] = None, | |||||
| callbacks: Callbacks = None, | |||||
| **kwargs) -> LLMResult: | |||||
| """ | |||||
| run predict by prompt messages and stop words. | |||||
| :param messages: | |||||
| :param stop: | |||||
| :param callbacks: | |||||
| :return: | |||||
| """ | |||||
| prompts = self._get_prompt_from_messages(messages) | |||||
| return self._client.generate([prompts], stop, callbacks) | |||||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||||
| """ | |||||
| get num tokens of prompt messages. | |||||
| :param messages: | |||||
| :return: | |||||
| """ | |||||
| prompts = self._get_prompt_from_messages(messages) | |||||
| return max(self._client.get_num_tokens(prompts), 0) | |||||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||||
| return decimal.Decimal('0') | |||||
| def get_currency(self): | |||||
| return 'RMB' | |||||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||||
| for k, v in provider_model_kwargs.items(): | |||||
| if hasattr(self.client, k): | |||||
| setattr(self.client, k, v) | |||||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||||
| if isinstance(ex, ValueError): | |||||
| return LLMBadRequestError(f"ChatGLM: {str(ex)}") | |||||
| else: | |||||
| return ex | |||||
| @classmethod | |||||
| def support_streaming(cls): | |||||
| return False |
import decimal
from typing import List, Optional, Any

from langchain import HuggingFaceHub
from langchain.callbacks.manager import Callbacks
from langchain.llms import HuggingFaceEndpoint
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class HuggingfaceHubModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            client = HuggingFaceEndpoint(
                endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
                task='text2text-generation',
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )
        else:
            client = HuggingFaceHub(
                repo_id=self.name,
                task=self.credentials['task_type'],
                model_kwargs=provider_model_kwargs,
                huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
                callbacks=self.callbacks,
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: stop words that truncate generation
        :param callbacks: callbacks invoked during generation
        :return: the LLMResult produced by the underlying client
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the token count of the prompt messages.

        :param messages: prompt messages to count
        :return: number of tokens
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # Hugging Face Hub does not expose per-token pricing
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.model_kwargs = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        return LLMBadRequestError(f"Huggingface Hub: {str(ex)}")

    @classmethod
    def support_streaming(cls):
        return False
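The `_init_client` branch above keys entirely off `huggingfacehub_api_type`. A sketch of the two credential shapes it expects; values are placeholders and the non-endpoint type string is an assumption (anything other than `inference_endpoints` routes to `HuggingFaceHub`):

# Illustrative credential dicts for the two branches above (placeholder values).
inference_endpoint_credentials = {
    'huggingfacehub_api_type': 'inference_endpoints',   # routes to HuggingFaceEndpoint
    'huggingfacehub_endpoint_url': 'https://<your-endpoint>.endpoints.huggingface.cloud',
    'huggingfacehub_api_token': 'hf_...',
}

hosted_hub_credentials = {
    'huggingfacehub_api_type': 'hosted_inference_api',  # assumed type string; any non-endpoint value works
    'task_type': 'text-generation',                     # passed through as HuggingFaceHub task
    'huggingfacehub_api_token': 'hf_...',
}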
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class MinimaxModel(BaseLLM):
    model_mode: ModelMode = ModelMode.COMPLETION

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return Minimax(
            model=self.name,
            model_kwargs={
                'stream': False
            },
            callbacks=self.callbacks,
            **self.credentials,
            **provider_model_kwargs
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: stop words that truncate generation
        :param callbacks: callbacks invoked during generation
        :return: the LLMResult produced by the underlying client
        """
        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the token count of the prompt messages.

        :param messages: prompt messages to count
        :return: number of tokens, never negative
        """
        prompts = self._get_prompt_from_messages(messages)
        return max(self._client.get_num_tokens(prompts), 0)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # per-token pricing is not calculated for Minimax
        return decimal.Decimal('0')

    def get_currency(self):
        return 'RMB'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        for k, v in provider_model_kwargs.items():
            if hasattr(self.client, k):
                setattr(self.client, k, v)

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, ValueError):
            return LLMBadRequestError(f"Minimax: {str(ex)}")
        else:
            return ex
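`_set_model_kwargs` here (and in `ChatGLMModel` above) syncs new kwargs onto the wrapped client in place rather than rebuilding it. A self-contained sketch of that `hasattr`/`setattr` pattern, using a stand-in class in place of the real langchain client:

# Self-contained sketch of the hasattr/setattr sync used by _set_model_kwargs.
# FakeClient stands in for the langchain client object.
class FakeClient:
    temperature = 0.9
    top_p = 1.0

client = FakeClient()
new_kwargs = {'temperature': 0.2, 'top_p': 0.8, 'unknown_param': 42}
for k, v in new_kwargs.items():
    if hasattr(client, k):      # silently skip params the client does not expose
        setattr(client, k, v)

assert client.temperature == 0.2 and not hasattr(client, 'unknown_param')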
import decimal
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from models.provider import ProviderType, ProviderQuotaType

COMPLETION_MODELS = [
    'text-davinci-003',  # 4,097 tokens
]

CHAT_MODELS = [
    'gpt-4',  # 8,192 tokens
    'gpt-4-32k',  # 32,768 tokens
    'gpt-3.5-turbo',  # 4,096 tokens
    'gpt-3.5-turbo-16k',  # 16,384 tokens
]

MODEL_MAX_TOKENS = {
    'gpt-4': 8192,
    'gpt-4-32k': 32768,
    'gpt-3.5-turbo': 4096,
    'gpt-3.5-turbo-16k': 16384,
    'text-davinci-003': 4097,
}
class OpenAIModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        if name in COMPLETION_MODELS:
            self.model_mode = ModelMode.COMPLETION
        else:
            self.model_mode = ModelMode.CHAT

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        if self.name in COMPLETION_MODELS:
            client = EnhanceOpenAI(
                model_name=self.name,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials,
                **provider_model_kwargs
            )
        else:
            # Fine-tuning is currently only available for the following base models:
            # davinci, curie, babbage, and ada.
            # This means that apart from the fixed `completion` models listed above,
            # all other fine-tuned models are also `completion` models.
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            client = EnhanceChatOpenAI(
                model_name=self.name,
                temperature=provider_model_kwargs.get('temperature'),
                max_tokens=provider_model_kwargs.get('max_tokens'),
                model_kwargs=extra_model_kwargs,
                streaming=self.streaming,
                callbacks=self.callbacks,
                request_timeout=60,
                **self.credentials
            )

        return client

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: stop words that truncate generation
        :param callbacks: callbacks invoked during generation
        :return: the LLMResult produced by the underlying client
        """
        if self.name == 'gpt-4' \
                and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
                and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
            raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.")

        prompts = self._get_prompt_from_messages(messages)
        return self._client.generate([prompts], stop, callbacks)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the token count of the prompt messages.

        :param messages: prompt messages to count
        :return: number of tokens, never negative
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, str):
            return self._client.get_num_tokens(prompts)
        else:
            # chat token counting adds per-message overhead; subtract one token
            # per message to approximate the prompt-only count
            return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0)
    def get_token_price(self, tokens: int, message_type: MessageType):
        # unit prices are in USD per 1,000 tokens
        model_unit_prices = {
            'gpt-4': {
                'prompt': decimal.Decimal('0.03'),
                'completion': decimal.Decimal('0.06'),
            },
            'gpt-4-32k': {
                'prompt': decimal.Decimal('0.06'),
                'completion': decimal.Decimal('0.12')
            },
            'gpt-3.5-turbo': {
                'prompt': decimal.Decimal('0.0015'),
                'completion': decimal.Decimal('0.002')
            },
            'gpt-3.5-turbo-16k': {
                'prompt': decimal.Decimal('0.003'),
                'completion': decimal.Decimal('0.004')
            },
            'text-davinci-003': {
                'prompt': decimal.Decimal('0.02'),
                'completion': decimal.Decimal('0.02')
            },
        }

        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = model_unit_prices[self.name]['prompt']
        else:
            unit_price = model_unit_prices[self.name]['completion']

        tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'),
                                                                  rounding=decimal.ROUND_HALF_UP)
        total_price = tokens_per_1k * unit_price
        return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

    def get_currency(self):
        return 'USD'
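To make the rounding behaviour above concrete, a worked example: 1,534 prompt tokens on gpt-3.5-turbo at USD 0.0015 per 1K tokens. This is pure arithmetic over the code shown, runnable standalone:

# Worked example of the price computation in get_token_price above.
from decimal import Decimal, ROUND_HALF_UP

tokens = 1534
unit_price = Decimal('0.0015')  # gpt-3.5-turbo prompt price, USD per 1K tokens
tokens_per_1k = (Decimal(tokens) / 1000).quantize(Decimal('0.001'), rounding=ROUND_HALF_UP)
total_price = (tokens_per_1k * unit_price).quantize(Decimal('0.0000001'), rounding=ROUND_HALF_UP)

assert tokens_per_1k == Decimal('1.534')
assert total_price == Decimal('0.0023010')  # 1.534 * 0.0015, rounded to 7 decimal places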
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        if self.name in COMPLETION_MODELS:
            for k, v in provider_model_kwargs.items():
                if hasattr(self.client, k):
                    setattr(self.client, k, v)
        else:
            extra_model_kwargs = {
                'top_p': provider_model_kwargs.get('top_p'),
                'frequency_penalty': provider_model_kwargs.get('frequency_penalty'),
                'presence_penalty': provider_model_kwargs.get('presence_penalty'),
            }

            self.client.temperature = provider_model_kwargs.get('temperature')
            self.client.max_tokens = provider_model_kwargs.get('max_tokens')
            self.client.model_kwargs = extra_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, openai.error.InvalidRequestError):
            logging.warning("Invalid request to OpenAI API.")
            return LLMBadRequestError(str(ex))
        elif isinstance(ex, openai.error.APIConnectionError):
            logging.warning("Failed to connect to OpenAI API.")
            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
            logging.warning("OpenAI service unavailable.")
            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
        elif isinstance(ex, openai.error.RateLimitError):
            return LLMRateLimitError(str(ex))
        elif isinstance(ex, openai.error.AuthenticationError):
            # return (not raise) so this mapper stays consistent with the other branches
            return LLMAuthorizationError(str(ex))
        elif isinstance(ex, openai.error.OpenAIError):
            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
    # def is_model_valid_or_raise(self):
    #     """
    #     check is a valid model.
    #
    #     :return:
    #     """
    #     credentials = self._model_provider.get_credentials()
    #
    #     try:
    #         result = openai.Model.retrieve(
    #             id=self.name,
    #             api_key=credentials.get('openai_api_key'),
    #             request_timeout=60
    #         )
    #
    #         if 'id' not in result or result['id'] != self.name:
    #             raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.")
    #     except openai.error.OpenAIError as e:
    #         raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}")
    #     except Exception as e:
    #         logging.exception("OpenAI Model retrieve failed.")
    #         raise e
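Note that `handle_exceptions` returns the mapped error rather than raising it, so the caller decides when to raise. A hedged sketch of that contract; the wrapper function below is an assumption about how the base class's public entry point uses this mapper, not code from this file:

# Sketch of the handle_exceptions contract -- the try/except shape is an
# assumption about how BaseLLM wraps _run; it is not code from this file.
def run_with_mapping(model, messages):
    try:
        return model._run(messages)
    except Exception as ex:
        # map provider-specific errors to Dify LLM errors, then raise
        raise model.handle_exceptions(ex)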
import decimal
from typing import List, Optional, Any

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from replicate.exceptions import ReplicateError, ModelError

from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.error import LLMBadRequestError
from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs


class ReplicateModel(BaseLLM):
    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION

        super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
        return EnhanceReplicate(
            model=self.name + ':' + self.credentials.get('model_version'),
            input=provider_model_kwargs,
            streaming=self.streaming,
            replicate_api_token=self.credentials.get('replicate_api_token'),
            callbacks=self.callbacks,
        )

    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: stop words that truncate generation
        :param callbacks: callbacks invoked during generation
        :return: the LLMResult produced by the underlying client
        """
        prompts = self._get_prompt_from_messages(messages)

        extra_kwargs = {}
        if isinstance(prompts, list):
            # chat mode: pass the system message separately and flatten the rest
            system_messages = [message for message in prompts if message.type == 'system']
            if system_messages:
                system_message = system_messages[0]
                extra_kwargs['system_prompt'] = system_message.content
                prompts = [message for message in prompts if message.type != 'system']

            prompts = get_buffer_string(prompts)

        # The maximum length the generated tokens can have.
        # Corresponds to the length of the input prompt + max_new_tokens.
        if 'max_length' in self._client.input:
            self._client.input['max_length'] = min(
                self._client.input['max_length'] + self.get_num_tokens(messages),
                self.model_rules.max_tokens.max
            )

        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)

    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the token count of the prompt messages.

        :param messages: prompt messages to count
        :return: number of tokens
        """
        prompts = self._get_prompt_from_messages(messages)
        if isinstance(prompts, list):
            prompts = get_buffer_string(prompts)

        return self._client.get_num_tokens(prompts)

    def get_token_price(self, tokens: int, message_type: MessageType):
        # Replicate bills by prediction seconds, not tokens
        return decimal.Decimal('0')

    def get_currency(self):
        return 'USD'

    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
        self.client.input = provider_model_kwargs

    def handle_exceptions(self, ex: Exception) -> Exception:
        if isinstance(ex, (ModelError, ReplicateError)):
            return LLMBadRequestError(f"Replicate: {str(ex)}")
        else:
            return ex

    @classmethod
    def support_streaming(cls):
        return True
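`_run` above splits the system message out into `system_prompt` and flattens the remaining chat turns with langchain's `get_buffer_string`. A self-contained sketch of that flattening step, using langchain's message types directly:

# Self-contained sketch of the system-prompt split and flattening done in _run.
from langchain.schema import AIMessage, HumanMessage, SystemMessage, get_buffer_string

chat = [SystemMessage(content='You are terse.'),
        HumanMessage(content='Hi'),
        AIMessage(content='Hello.')]

system_prompt = next(m.content for m in chat if m.type == 'system')
flattened = get_buffer_string([m for m in chat if m.type != 'system'])

assert system_prompt == 'You are terse.'
assert flattened == 'Human: Hi\nAI: Hello.'  # default Human/AI prefixes, newline-joined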