| @@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path): | |||
| def main(): | |||
| has_chinese = False | |||
| excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py'] | |||
| excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', | |||
| 'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py'] | |||
| for root, _, files in os.walk("."): | |||
| for file in files: | |||
| @@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public | |||
| NOTION_CLIENT_SECRET=your-client-secret | |||
| NOTION_CLIENT_ID=your-client-id | |||
| NOTION_INTERNAL_SECRET=your-internal-secret | |||
| # Hosted Model Credentials | |||
| HOSTED_OPENAI_ENABLED=false | |||
| HOSTED_OPENAI_API_KEY= | |||
| HOSTED_OPENAI_API_BASE= | |||
| HOSTED_OPENAI_API_ORGANIZATION= | |||
| HOSTED_OPENAI_QUOTA_LIMIT=200 | |||
| HOSTED_OPENAI_PAID_ENABLED=false | |||
| HOSTED_OPENAI_PAID_STRIPE_PRICE_ID= | |||
| HOSTED_OPENAI_PAID_INCREASE_QUOTA=1 | |||
| HOSTED_AZURE_OPENAI_ENABLED=false | |||
| HOSTED_AZURE_OPENAI_API_KEY= | |||
| HOSTED_AZURE_OPENAI_API_BASE= | |||
| HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200 | |||
| HOSTED_ANTHROPIC_ENABLED=false | |||
| HOSTED_ANTHROPIC_API_BASE= | |||
| HOSTED_ANTHROPIC_API_KEY= | |||
| HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000 | |||
| HOSTED_ANTHROPIC_PAID_ENABLED=false | |||
| HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID= | |||
| HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1 | |||
| STRIPE_API_KEY= | |||
| STRIPE_WEBHOOK_SECRET= | |||
| @@ -16,8 +16,9 @@ from flask import Flask, request, Response, session | |||
| import flask_login | |||
| from flask_cors import CORS | |||
| from core.model_providers.providers import hosted | |||
| from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \ | |||
| ext_database, ext_storage, ext_mail | |||
| ext_database, ext_storage, ext_mail, ext_stripe | |||
| from extensions.ext_database import db | |||
| from extensions.ext_login import login_manager | |||
| @@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask: | |||
| register_blueprints(app) | |||
| register_commands(app) | |||
| core.init_app(app) | |||
| hosted.init_app(app) | |||
| return app | |||
| @@ -88,6 +89,7 @@ def initialize_extensions(app): | |||
| ext_login.init_app(app) | |||
| ext_mail.init_app(app) | |||
| ext_sentry.init_app(app) | |||
| ext_stripe.init_app(app) | |||
| def _create_tenant_for_account(account): | |||
| @@ -246,5 +248,18 @@ def threads(): | |||
| } | |||
| @app.route('/db-pool-stat') | |||
| def pool_stat(): | |||
| engine = db.engine | |||
| return { | |||
| 'pool_size': engine.pool.size(), | |||
| 'checked_in_connections': engine.pool.checkedin(), | |||
| 'checked_out_connections': engine.pool.checkedout(), | |||
| 'overflow_connections': engine.pool.overflow(), | |||
| 'connection_timeout': engine.pool.timeout(), | |||
| 'recycle_time': db.engine.pool._recycle | |||
| } | |||
| if __name__ == '__main__': | |||
| app.run(host='0.0.0.0', port=5001) | |||
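The `/db-pool-stat` route above exposes the SQLAlchemy pool counters for quick diagnostics. A usage sketch, assuming the API is reachable on the port from `app.run()` above and that `requests` is installed; host and any auth layer are not specified in this diff:

```python
import requests

# Poll the new diagnostics endpoint; the keys match the dict returned above.
stats = requests.get("http://localhost:5001/db-pool-stat", timeout=5).json()
print(f"{stats['checked_out_connections']}/{stats['pool_size']} connections in use, "
      f"overflow={stats['overflow_connections']}, recycle={stats['recycle_time']}s")
```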
| @@ -1,5 +1,5 @@ | |||
| import datetime | |||
| import logging | |||
| import math | |||
| import random | |||
| import string | |||
| import time | |||
| @@ -9,18 +9,18 @@ from flask import current_app | |||
| from werkzeug.exceptions import NotFound | |||
| from core.index.index import IndexBuilder | |||
| from core.model_providers.providers.hosted import hosted_model_providers | |||
| from libs.password import password_pattern, valid_password, hash_password | |||
| from libs.helper import email as email_validate | |||
| from extensions.ext_database import db | |||
| from libs.rsa import generate_key_pair | |||
| from models.account import InvitationCode, Tenant | |||
| from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment | |||
| from models.dataset import Dataset, DatasetQuery, Document | |||
| from models.model import Account | |||
| import secrets | |||
| import base64 | |||
| from models.provider import Provider, ProviderName | |||
| from services.provider_service import ProviderService | |||
| from models.provider import Provider, ProviderType, ProviderQuotaType | |||
| @click.command('reset-password', help='Reset the account password.') | |||
| @@ -251,26 +251,37 @@ def clean_unused_dataset_indexes(): | |||
| @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.') | |||
| def sync_anthropic_hosted_providers(): | |||
| if not hosted_model_providers.anthropic: | |||
| click.echo(click.style('Anthropic hosted provider is not configured.', fg='red')) | |||
| return | |||
| click.echo(click.style('Start sync anthropic hosted providers.', fg='green')) | |||
| count = 0 | |||
| page = 1 | |||
| while True: | |||
| try: | |||
| tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50) | |||
| providers = db.session.query(Provider).filter( | |||
| Provider.provider_name == 'anthropic', | |||
| Provider.provider_type == ProviderType.SYSTEM.value, | |||
| Provider.quota_type == ProviderQuotaType.TRIAL.value, | |||
| ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100) | |||
| except NotFound: | |||
| break | |||
| page += 1 | |||
| for tenant in tenants: | |||
| for provider in providers: | |||
| try: | |||
| click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id)) | |||
| ProviderService.create_system_provider( | |||
| tenant, | |||
| ProviderName.ANTHROPIC.value, | |||
| current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'], | |||
| True | |||
| ) | |||
| click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id)) | |||
| original_quota_limit = provider.quota_limit | |||
| new_quota_limit = hosted_model_providers.anthropic.quota_limit | |||
| division = math.ceil(new_quota_limit / 1000) | |||
| provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \ | |||
| else original_quota_limit * division | |||
| provider.quota_used = division * provider.quota_used | |||
| db.session.commit() | |||
| count += 1 | |||
| except Exception as e: | |||
| click.echo(click.style( | |||
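The reworked `sync-anthropic-hosted-providers` command above migrates Anthropic trial quotas from call counts to tokens: `division = ceil(new_quota_limit / 1000)` scales both the limit and the amount already used. A standalone sketch of that arithmetic, using the new default `HOSTED_ANTHROPIC_QUOTA_LIMIT` of 1,000,000 introduced in this change (the function name is illustrative):

```python
import math

def rescale_trial_quota(original_limit: int, original_used: int,
                        new_limit: int = 1_000_000) -> tuple[int, int]:
    # Same arithmetic as the command above.
    division = math.ceil(new_limit / 1000)
    limit = new_limit if original_limit == 1000 else original_limit * division
    used = original_used * division
    return limit, used

print(rescale_trial_quota(1000, 200))  # (1000000, 200000): old default limit with 200 calls used
print(rescale_trial_quota(2000, 0))    # (2000000, 0): custom limits are scaled proportionally
```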
| @@ -41,6 +41,7 @@ DEFAULTS = { | |||
| 'SESSION_USE_SIGNER': 'True', | |||
| 'DEPLOY_ENV': 'PRODUCTION', | |||
| 'SQLALCHEMY_POOL_SIZE': 30, | |||
| 'SQLALCHEMY_POOL_RECYCLE': 3600, | |||
| 'SQLALCHEMY_ECHO': 'False', | |||
| 'SENTRY_TRACES_SAMPLE_RATE': 1.0, | |||
| 'SENTRY_PROFILES_SAMPLE_RATE': 1.0, | |||
| @@ -50,9 +51,16 @@ DEFAULTS = { | |||
| 'PDF_PREVIEW': 'True', | |||
| 'LOG_LEVEL': 'INFO', | |||
| 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', | |||
| 'DEFAULT_LLM_PROVIDER': 'openai', | |||
| 'OPENAI_HOSTED_QUOTA_LIMIT': 200, | |||
| 'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000, | |||
| 'HOSTED_OPENAI_QUOTA_LIMIT': 200, | |||
| 'HOSTED_OPENAI_ENABLED': 'False', | |||
| 'HOSTED_OPENAI_PAID_ENABLED': 'False', | |||
| 'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1, | |||
| 'HOSTED_AZURE_OPENAI_ENABLED': 'False', | |||
| 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, | |||
| 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000, | |||
| 'HOSTED_ANTHROPIC_ENABLED': 'False', | |||
| 'HOSTED_ANTHROPIC_PAID_ENABLED': 'False', | |||
| 'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1, | |||
| 'TENANT_DOCUMENT_COUNT': 100, | |||
| 'CLEAN_DAY_SETTING': 30 | |||
| } | |||
| @@ -182,7 +190,10 @@ class Config: | |||
| } | |||
| self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}" | |||
| self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))} | |||
| self.SQLALCHEMY_ENGINE_OPTIONS = { | |||
| 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), | |||
| 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) | |||
| } | |||
| self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO') | |||
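`pool_recycle` tells SQLAlchemy to discard pooled connections older than the given number of seconds, which avoids "server closed the connection" errors when the database drops idle connections. A minimal sketch of what the engine options above amount to, with a placeholder DSN:

```python
from sqlalchemy import create_engine

# Equivalent of SQLALCHEMY_ENGINE_OPTIONS with the defaults added in this change.
engine = create_engine(
    "postgresql://user:password@localhost:5432/dify",  # placeholder credentials
    pool_size=30,       # SQLALCHEMY_POOL_SIZE
    pool_recycle=3600,  # SQLALCHEMY_POOL_RECYCLE: recycle connections after one hour
)
```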
| @@ -194,20 +205,35 @@ class Config: | |||
| self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') | |||
| # hosted provider credentials | |||
| self.OPENAI_API_KEY = get_env('OPENAI_API_KEY') | |||
| self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY') | |||
| self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT') | |||
| self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT') | |||
| self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED') | |||
| self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY') | |||
| self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE') | |||
| self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') | |||
| self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT') | |||
| self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') | |||
| self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID') | |||
| self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA')) | |||
| self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') | |||
| self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') | |||
| self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE') | |||
| self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT') | |||
| self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED') | |||
| self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE') | |||
| self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') | |||
| self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT') | |||
| self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') | |||
| self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID') | |||
| self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA') | |||
| self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') | |||
| self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') | |||
| # By default it is False | |||
| # You could disable it for compatibility with certain OpenAPI providers | |||
| self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION') | |||
| # For temp use only | |||
| # set default LLM provider, default is 'openai', support `azure_openai` | |||
| self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') | |||
| # notion import setting | |||
| self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') | |||
| self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') | |||
| @@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate | |||
| from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source | |||
| # Import workspace controllers | |||
| from .workspace import workspace, members, model_providers, account, tool_providers | |||
| from .workspace import workspace, members, providers, model_providers, account, tool_providers, models | |||
| # Import explore controllers | |||
| from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio | |||
| # Import universal chat controllers | |||
| from .universal_chat import chat, conversation, message, parameter, audio | |||
| # Import webhook controllers | |||
| from .webhook import stripe | |||
| @@ -2,16 +2,17 @@ | |||
| import json | |||
| from datetime import datetime | |||
| import flask | |||
| from flask_login import login_required, current_user | |||
| from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs | |||
| from werkzeug.exceptions import Unauthorized, Forbidden | |||
| from werkzeug.exceptions import Forbidden | |||
| from constants.model_template import model_templates, demo_model_templates | |||
| from controllers.console import api | |||
| from controllers.console.app.error import AppNotFoundError | |||
| from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from events.app_event import app_was_created, app_was_deleted | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| @@ -126,9 +127,9 @@ class AppListApi(Resource): | |||
| if args['model_config'] is not None: | |||
| # validate config | |||
| model_configuration = AppModelConfigService.validate_configuration( | |||
| tenant_id=current_user.current_tenant_id, | |||
| account=current_user, | |||
| config=args['model_config'], | |||
| mode=args['mode'] | |||
| config=args['model_config'] | |||
| ) | |||
| app = App( | |||
| @@ -164,6 +165,21 @@ class AppListApi(Resource): | |||
| app = App(**model_config_template['app']) | |||
| app_model_config = AppModelConfig(**model_config_template['model_config']) | |||
| default_model = ModelFactory.get_default_model( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=ModelType.TEXT_GENERATION | |||
| ) | |||
| if default_model: | |||
| model_dict = app_model_config.model_dict | |||
| model_dict['provider'] = default_model.provider_name | |||
| model_dict['name'] = default_model.model_name | |||
| app_model_config.model = json.dumps(model_dict) | |||
| else: | |||
| raise ProviderNotInitializeError( | |||
| f"No Text Generation Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| app.name = args['name'] | |||
| app.mode = args['mode'] | |||
| app.icon = args['icon'] | |||
| @@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from flask_restful import Resource | |||
| from services.audio_service import AudioService | |||
| @@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.conversation_message_task import PubHandler | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value | |||
| from flask_restful import Resource, reqparse | |||
| @@ -41,8 +41,11 @@ class CompletionMessageApi(Resource): | |||
| parser.add_argument('inputs', type=dict, required=True, location='json') | |||
| parser.add_argument('query', type=str, location='json') | |||
| parser.add_argument('model_config', type=dict, required=True, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] != 'blocking' | |||
| account = flask_login.current_user | |||
| try: | |||
| @@ -51,7 +54,7 @@ class CompletionMessageApi(Resource): | |||
| user=account, | |||
| args=args, | |||
| from_source='console', | |||
| streaming=True, | |||
| streaming=streaming, | |||
| is_model_config_override=True | |||
| ) | |||
| @@ -111,8 +114,11 @@ class ChatMessageApi(Resource): | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('model_config', type=dict, required=True, location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') | |||
| args = parser.parse_args() | |||
| streaming = args['response_mode'] != 'blocking' | |||
| account = flask_login.current_user | |||
| try: | |||
| @@ -121,7 +127,7 @@ class ChatMessageApi(Resource): | |||
| user=account, | |||
| args=args, | |||
| from_source='console', | |||
| streaming=True, | |||
| streaming=streaming, | |||
| is_model_config_override=True | |||
| ) | |||
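Both console debug endpoints above now accept an optional `response_mode`; anything other than `blocking` keeps the previous streaming behaviour. A sketch of the relevant payload fields (taken from the parser arguments in this diff) and the derived flag; the surrounding request plumbing is not shown here:

```python
payload = {
    "inputs": {},
    "query": "Hello",
    "model_config": {},           # the app's model configuration dict
    "response_mode": "blocking",  # omit or set to "streaming" to keep SSE streaming
}

# Mirrors the controller logic above.
streaming = payload.get("response_mode") != "blocking"
print(streaming)  # False -> the completion service returns a single blocking response
```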
| @@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \ | |||
| LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError | |||
| @@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni | |||
| AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value, TimestampField | |||
| from libs.infinite_scroll_pagination import InfiniteScrollPagination | |||
| @@ -28,9 +28,9 @@ class ModelConfigResource(Resource): | |||
| # validate config | |||
| model_configuration = AppModelConfigService.validate_configuration( | |||
| tenant_id=current_user.current_tenant_id, | |||
| account=current_user, | |||
| config=request.json, | |||
| mode=app_model.mode | |||
| config=request.json | |||
| ) | |||
| new_app_model_config = AppModelConfig( | |||
| @@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource): | |||
| # validate args | |||
| DocumentService.estimate_args_validate(args) | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule']) | |||
| response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule']) | |||
| return response, 200 | |||
| @@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with | |||
| from werkzeug.exceptions import NotFound, Forbidden | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ProviderNotInitializeError | |||
| from controllers.console.datasets.error import DatasetNameDuplicateError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| from models.dataset import DocumentSegment, Document | |||
| @@ -97,6 +100,15 @@ class DatasetListApi(Resource): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| try: | |||
| dataset = DatasetService.create_empty_dataset( | |||
| tenant_id=current_user.current_tenant_id, | |||
| @@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource): | |||
| raise NotFound("File not found.") | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form']) | |||
| try: | |||
| response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, | |||
| args['process_rule'], args['doc_form']) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| elif args['info_list']['data_source_type'] == 'notion_import': | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'], | |||
| args['process_rule'], args['doc_form']) | |||
| try: | |||
| response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, | |||
| args['info_list']['notion_info_list'], | |||
| args['process_rule'], args['doc_form']) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| else: | |||
| raise ValueError('Data source type not support') | |||
| return response, 200 | |||
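The same guard appears before every dataset and document operation in this PR: resolve the tenant's embedding model up front and convert a missing provider into `ProviderNotInitializeError` instead of letting it surface as a 500. A sketch of that pattern factored into a helper; `ensure_embedding_model` is a hypothetical name, while the imports are the ones used in the hunks above:

```python
from controllers.console.app.error import ProviderNotInitializeError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory

def ensure_embedding_model(tenant_id: str):
    """Raise a user-facing error when no embedding provider is configured."""
    try:
        return ModelFactory.get_embedding_model(tenant_id=tenant_id)
    except LLMBadRequestError:
        raise ProviderNotInitializeError(
            "No Embedding Model available. Please configure a valid provider "
            "in the Settings -> Model Provider.")
```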
| @@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| LLMBadRequestError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_redis import redis_client | |||
| from libs.helper import TimestampField | |||
| from extensions.ext_database import db | |||
| @@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource): | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) | |||
| except ProviderTokenNotInitError as ex: | |||
| @@ -319,6 +330,15 @@ class DatasetInitApi(Resource): | |||
| parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| try: | |||
| ModelFactory.get_embedding_model( | |||
| tenant_id=current_user.current_tenant_id | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| @@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict) | |||
| try: | |||
| response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file], | |||
| data_process_rule_dict) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| return response | |||
| @@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| raise NotFound("File not found.") | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict) | |||
| try: | |||
| response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details, | |||
| data_process_rule_dict) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| elif dataset.data_source_type: | |||
| indexing_runner = IndexingRunner() | |||
| response = indexing_runner.notion_indexing_estimate(info_list, | |||
| data_process_rule_dict) | |||
| try: | |||
| response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, | |||
| info_list, | |||
| data_process_rule_dict) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"No Embedding Model available. Please configure a valid provider " | |||
| f"in the Settings -> Model Provider.") | |||
| else: | |||
| raise ValueError('Data source type not support') | |||
| return response | |||
| @@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu | |||
| from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import TimestampField | |||
| from services.dataset_service import DatasetService | |||
| from services.hit_testing_service import HitTestingService | |||
| @@ -102,6 +102,8 @@ class HitTestingApi(Resource): | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except ValueError as e: | |||
| raise ValueError(str(e)) | |||
| except Exception as e: | |||
| logging.exception("Hit testing failed.") | |||
| raise InternalServerError(str(e)) | |||
| @@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia | |||
| NoAudioUploadedError, AudioTooLargeError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| @@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail | |||
| from controllers.console.explore.error import NotCompletionAppError, NotChatAppError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value, TimestampField | |||
| from services.completion_service import CompletionService | |||
| @@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields | |||
| from controllers.console import api | |||
| from controllers.console.explore.wraps import InstalledAppResource | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from models.provider import ProviderName | |||
| from models.model import InstalledApp | |||
| @@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource): | |||
| """Retrieve app parameters.""" | |||
| app_model = installed_app.app | |||
| app_model_config = app_model.app_model_config | |||
| provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1') | |||
| return { | |||
| 'opening_statement': app_model_config.opening_statement, | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list | |||
| } | |||
| @@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia | |||
| NoAudioUploadedError, AudioTooLargeError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| @@ -12,9 +12,8 @@ from controllers.console import api | |||
| from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.constant import llm_constant | |||
| from core.conversation_message_task import PubHandler | |||
| from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \ | |||
| LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('query', type=str, required=True, location='json') | |||
| parser.add_argument('conversation_id', type=uuid_value, location='json') | |||
| parser.add_argument('provider', type=str, required=True, location='json') | |||
| parser.add_argument('model', type=str, required=True, location='json') | |||
| parser.add_argument('tools', type=list, required=True, location='json') | |||
| args = parser.parse_args() | |||
| @@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource): | |||
| # update app model config | |||
| args['model_config'] = app_model_config.to_dict() | |||
| args['model_config']['model']['name'] = args['model'] | |||
| if not llm_constant.models[args['model']]: | |||
| raise ValueError("Model not exists.") | |||
| args['model_config']['model']['provider'] = llm_constant.models[args['model']] | |||
| args['model_config']['model']['provider'] = args['provider'] | |||
| args['model_config']['agent_mode']['tools'] = args['tools'] | |||
| if not args['model_config']['agent_mode']['tools']: | |||
| @@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError | |||
| from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value, TimestampField | |||
| from services.errors.conversation import ConversationNotExistsError | |||
| @@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields | |||
| from controllers.console import api | |||
| from controllers.console.universal_chat.wraps import UniversalChatResource | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from models.provider import ProviderName | |||
| from models.model import App | |||
| @@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource): | |||
| """Retrieve app parameters.""" | |||
| app_model = universal_app | |||
| app_model_config = app_model.app_model_config | |||
| provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1') | |||
| return { | |||
| 'opening_statement': app_model_config.opening_statement, | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| } | |||
| @@ -0,0 +1,53 @@ | |||
| import logging | |||
| import stripe | |||
| from flask import request, current_app | |||
| from flask_restful import Resource | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import only_edition_cloud | |||
| from services.provider_checkout_service import ProviderCheckoutService | |||
| class StripeWebhookApi(Resource): | |||
| @setup_required | |||
| @only_edition_cloud | |||
| def post(self): | |||
| payload = request.data | |||
| sig_header = request.headers.get('STRIPE_SIGNATURE') | |||
| webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET') | |||
| try: | |||
| event = stripe.Webhook.construct_event( | |||
| payload, sig_header, webhook_secret | |||
| ) | |||
| except ValueError as e: | |||
| # Invalid payload | |||
| return 'Invalid payload', 400 | |||
| except stripe.error.SignatureVerificationError as e: | |||
| # Invalid signature | |||
| return 'Invalid signature', 400 | |||
| # Handle the checkout.session.completed event | |||
| if event['type'] == 'checkout.session.completed': | |||
| logging.debug(event['data']['object']['id']) | |||
| logging.debug(event['data']['object']['amount_subtotal']) | |||
| logging.debug(event['data']['object']['currency']) | |||
| logging.debug(event['data']['object']['payment_intent']) | |||
| logging.debug(event['data']['object']['payment_status']) | |||
| logging.debug(event['data']['object']['metadata']) | |||
| # Fulfill the purchase... | |||
| provider_checkout_service = ProviderCheckoutService() | |||
| try: | |||
| provider_checkout_service.fulfill_provider_order(event) | |||
| except Exception as e: | |||
| logging.debug(str(e)) | |||
| return 'success', 200 | |||
| return 'success', 200 | |||
| api.add_resource(StripeWebhookApi, '/webhook/stripe') | |||
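`StripeWebhookApi` verifies the raw payload against `STRIPE_WEBHOOK_SECRET` before touching the checkout session. A minimal sketch of that verification in isolation, assuming the official `stripe` package; the payload, header, and secret here are made up, so the call is expected to be rejected:

```python
import stripe

payload = b'{"type": "checkout.session.completed"}'
sig_header = "t=0,v1=deadbeef"        # value of the STRIPE_SIGNATURE request header
webhook_secret = "whsec_test_secret"  # STRIPE_WEBHOOK_SECRET

try:
    event = stripe.Webhook.construct_event(payload, sig_header, webhook_secret)
except stripe.error.SignatureVerificationError:
    print("rejected: signature does not match the webhook secret")
```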
| @@ -1,24 +1,18 @@ | |||
| # -*- coding:utf-8 -*- | |||
| import base64 | |||
| import json | |||
| import logging | |||
| from flask import current_app | |||
| from flask_login import login_required, current_user | |||
| from flask_restful import Resource, reqparse, abort | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ProviderNotInitializeError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.llm.provider.errors import ValidateFailedError | |||
| from extensions.ext_database import db | |||
| from libs import rsa | |||
| from models.provider import Provider, ProviderType, ProviderName | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from services.provider_checkout_service import ProviderCheckoutService | |||
| from services.provider_service import ProviderService | |||
| class ProviderListApi(Resource): | |||
| class ModelProviderListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -26,156 +20,115 @@ class ProviderListApi(Resource): | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| """ | |||
| If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:, | |||
| azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the | |||
| rest is replaced by * and the last two bits are displayed in plaintext | |||
| If the type is other, decode and return the Token field directly, the field displays the first 6 bits in | |||
| plaintext, the rest is replaced by * and the last two bits are displayed in plaintext | |||
| """ | |||
| ProviderService.init_supported_provider(current_user.current_tenant) | |||
| providers = Provider.query.filter_by(tenant_id=tenant_id).all() | |||
| provider_list = [ | |||
| { | |||
| 'provider_name': p.provider_name, | |||
| 'provider_type': p.provider_type, | |||
| 'is_valid': p.is_valid, | |||
| 'last_used': p.last_used, | |||
| 'is_enabled': p.is_enabled, | |||
| **({ | |||
| 'quota_type': p.quota_type, | |||
| 'quota_limit': p.quota_limit, | |||
| 'quota_used': p.quota_used | |||
| } if p.provider_type == ProviderType.SYSTEM.value else {}), | |||
| 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant, | |||
| ProviderName(p.provider_name), only_custom=True) | |||
| if p.provider_type == ProviderType.CUSTOM.value else None | |||
| } | |||
| for p in providers | |||
| ] | |||
| provider_service = ProviderService() | |||
| provider_list = provider_service.get_provider_list(tenant_id) | |||
| return provider_list | |||
| class ProviderTokenApi(Resource): | |||
| class ModelProviderValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| if provider not in [p.value for p in ProviderName]: | |||
| abort(404) | |||
| def post(self, provider_name: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| result = True | |||
| error = None | |||
| try: | |||
| provider_service.custom_provider_config_validate( | |||
| provider_name=provider_name, | |||
| config=args['config'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| if not result: | |||
| response['error'] = error | |||
| return response | |||
| class ModelProviderUpdateApi(Resource): | |||
| # The role of the current user in the ta table must be admin or owner | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| logging.log(logging.ERROR, | |||
| f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}') | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=ProviderService.get_token_type( | |||
| tenant=current_user.current_tenant, | |||
| provider_name=ProviderName(provider) | |||
| ), required=True, nullable=False, location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| if args['token']: | |||
| try: | |||
| ProviderService.validate_provider_configs( | |||
| tenant=current_user.current_tenant, | |||
| provider_name=ProviderName(provider), | |||
| configs=args['token'] | |||
| ) | |||
| token_is_valid = True | |||
| except ValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| base64_encrypted_token = ProviderService.get_encrypted_token( | |||
| tenant=current_user.current_tenant, | |||
| provider_name=ProviderName(provider), | |||
| configs=args['token'] | |||
| provider_service = ProviderService() | |||
| try: | |||
| provider_service.save_custom_provider_config( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| config=args['config'] | |||
| ) | |||
| else: | |||
| base64_encrypted_token = None | |||
| token_is_valid = False | |||
| tenant = current_user.current_tenant | |||
| provider_model = db.session.query(Provider).filter( | |||
| Provider.tenant_id == tenant.id, | |||
| Provider.provider_name == provider, | |||
| Provider.provider_type == ProviderType.CUSTOM.value | |||
| ).first() | |||
| # Only allow updating token for CUSTOM provider type | |||
| if provider_model: | |||
| provider_model.encrypted_config = base64_encrypted_token | |||
| provider_model.is_valid = token_is_valid | |||
| else: | |||
| provider_model = Provider(tenant_id=tenant.id, provider_name=provider, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| encrypted_config=base64_encrypted_token, | |||
| is_valid=token_is_valid) | |||
| db.session.add(provider_model) | |||
| if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid: | |||
| other_providers = db.session.query(Provider).filter( | |||
| Provider.tenant_id == tenant.id, | |||
| Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]), | |||
| Provider.provider_name != provider, | |||
| Provider.provider_type == ProviderType.CUSTOM.value | |||
| ).all() | |||
| for other_provider in other_providers: | |||
| other_provider.is_valid = False | |||
| db.session.commit() | |||
| if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, | |||
| ProviderName.HUGGINGFACEHUB.value]: | |||
| return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201 | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 201 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| class ProviderTokenValidateApi(Resource): | |||
| provider_service = ProviderService() | |||
| provider_service.delete_custom_provider( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name | |||
| ) | |||
| return {'result': 'success'}, 204 | |||
| class ModelProviderModelValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| if provider not in [p.value for p in ProviderName]: | |||
| abort(404) | |||
| def post(self, provider_name: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', type=ProviderService.get_token_type( | |||
| tenant=current_user.current_tenant, | |||
| provider_name=ProviderName(provider) | |||
| ), required=True, nullable=False, location='json') | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| # todo: remove this when the provider is supported | |||
| if provider in [ProviderName.COHERE.value, | |||
| ProviderName.HUGGINGFACEHUB.value]: | |||
| return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'} | |||
| provider_service = ProviderService() | |||
| result = True | |||
| error = None | |||
| try: | |||
| ProviderService.validate_provider_configs( | |||
| tenant=current_user.current_tenant, | |||
| provider_name=ProviderName(provider), | |||
| configs=args['token'] | |||
| provider_service.custom_provider_model_config_validate( | |||
| provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type=args['model_type'], | |||
| config=args['config'] | |||
| ) | |||
| except ValidateFailedError as e: | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(e) | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| @@ -185,91 +138,148 @@ class ProviderTokenValidateApi(Resource): | |||
| return response | |||
| class ProviderSystemApi(Resource): | |||
| class ModelProviderModelUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def put(self, provider): | |||
| if provider not in [p.value for p in ProviderName]: | |||
| abort(404) | |||
| def post(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('is_enabled', type=bool, required=True, location='json') | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||
| parser.add_argument('config', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| tenant = current_user.current_tenant_id | |||
| provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first() | |||
| if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value: | |||
| provider_model.is_valid = args['is_enabled'] | |||
| db.session.commit() | |||
| elif not provider_model: | |||
| if provider == ProviderName.OPENAI.value: | |||
| quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'] | |||
| elif provider == ProviderName.ANTHROPIC.value: | |||
| quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'] | |||
| else: | |||
| quota_limit = 0 | |||
| ProviderService.create_system_provider( | |||
| tenant, | |||
| provider, | |||
| quota_limit, | |||
| args['is_enabled'] | |||
| provider_service = ProviderService() | |||
| try: | |||
| provider_service.add_or_save_custom_provider_model_config( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type=args['model_type'], | |||
| config=args['config'] | |||
| ) | |||
| else: | |||
| abort(403) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'} | |||
| return {'result': 'success'}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider): | |||
| if provider not in [p.value for p in ProviderName]: | |||
| abort(404) | |||
| def delete(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='args') | |||
| args = parser.parse_args() | |||
| # The role of the current user in the ta table must be admin or owner | |||
| provider_service = ProviderService() | |||
| provider_service.delete_custom_provider_model( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type=args['model_type'] | |||
| ) | |||
| return {'result': 'success'}, 204 | |||
| class PreferredProviderTypeUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider_name: str): | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id, | |||
| Provider.provider_name == provider, | |||
| Provider.provider_type == ProviderType.SYSTEM.value).first() | |||
| system_model = None | |||
| if provider_model: | |||
| system_model = { | |||
| 'result': 'success', | |||
| 'provider': { | |||
| 'provider_name': provider_model.provider_name, | |||
| 'provider_type': provider_model.provider_type, | |||
| 'is_valid': provider_model.is_valid, | |||
| 'last_used': provider_model.last_used, | |||
| 'is_enabled': provider_model.is_enabled, | |||
| 'quota_type': provider_model.quota_type, | |||
| 'quota_limit': provider_model.quota_limit, | |||
| 'quota_used': provider_model.quota_used | |||
| } | |||
| } | |||
| else: | |||
| abort(404) | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, | |||
| choices=['system', 'custom'], location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| provider_service.switch_preferred_provider( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| preferred_provider_type=args['preferred_provider_type'] | |||
| ) | |||
| return {'result': 'success'} | |||
| class ModelProviderModelParameterRuleApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') | |||
| args = parser.parse_args() | |||
| return system_model | |||
| provider_service = ProviderService() | |||
| try: | |||
| parameter_rules = provider_service.get_model_parameter_rules( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_provider_name=provider_name, | |||
| model_name=args['model_name'], | |||
| model_type='text-generation' | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| f"Current Text Generation Model is invalid. Please switch to the available model.") | |||
| rules = { | |||
| k: { | |||
| 'enabled': v.enabled, | |||
| 'min': v.min, | |||
| 'max': v.max, | |||
| 'default': v.default | |||
| } | |||
| for k, v in vars(parameter_rules).items() | |||
| } | |||
| api.add_resource(ProviderTokenApi, '/providers/<provider>/token', | |||
| endpoint='current_providers_token') # Deprecated | |||
| api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate', | |||
| endpoint='current_providers_token_validate') # Deprecated | |||
| return rules | |||
| api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token', | |||
| endpoint='workspaces_current_providers_token') # PUT for updating provider token | |||
| api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate', | |||
| endpoint='workspaces_current_providers_token_validate') # POST for validating provider token | |||
| api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list | |||
| api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system', | |||
| endpoint='workspaces_current_providers_system') # GET for getting provider quota, PUT for updating provider status | |||
| class ModelProviderPaymentCheckoutUrlApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, provider_name: str): | |||
| provider_service = ProviderCheckoutService() | |||
| provider_checkout = provider_service.create_checkout( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider_name, | |||
| account=current_user | |||
| ) | |||
| return { | |||
| 'url': provider_checkout.get_checkout_url() | |||
| } | |||
| api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') | |||
| api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate') | |||
| api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>') | |||
| api.add_resource(ModelProviderModelValidateApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/models/validate') | |||
| api.add_resource(ModelProviderModelUpdateApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/models') | |||
| api.add_resource(PreferredProviderTypeUpdateApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type') | |||
| api.add_resource(ModelProviderModelParameterRuleApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules') | |||
| api.add_resource(ModelProviderPaymentCheckoutUrlApi, | |||
| '/workspaces/current/model-providers/<string:provider_name>/checkout-url') | |||
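The routes registered above replace the old `/workspaces/current/providers/.../token` endpoints with provider-level and model-level credential management. A hypothetical client walk-through; the `/console/api` prefix, the session-based auth, and the `openai_api_key` config key are assumptions not shown in this diff:

```python
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
session = requests.Session()                # assumed to carry a logged-in console session

# 1. Validate custom credentials for a provider without saving them.
session.post(f"{BASE}/workspaces/current/model-providers/openai/validate",
             json={"config": {"openai_api_key": "sk-placeholder"}})

# 2. Save the credentials (admin/owner only), then prefer them over the hosted quota.
session.post(f"{BASE}/workspaces/current/model-providers/openai",
             json={"config": {"openai_api_key": "sk-placeholder"}})
session.post(f"{BASE}/workspaces/current/model-providers/openai/preferred-provider-type",
             json={"preferred_provider_type": "custom"})
```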
| @@ -0,0 +1,108 @@ | |||
| from flask_login import login_required, current_user | |||
| from flask_restful import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from models.provider import ProviderType | |||
| from services.provider_service import ProviderService | |||
| class DefaultModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='args') | |||
| args = parser.parse_args() | |||
| tenant_id = current_user.current_tenant_id | |||
| provider_service = ProviderService() | |||
| default_model = provider_service.get_default_model_of_model_type( | |||
| tenant_id=tenant_id, | |||
| model_type=args['model_type'] | |||
| ) | |||
| if not default_model: | |||
| return None | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider( | |||
| tenant_id, | |||
| default_model.provider_name | |||
| ) | |||
| if not model_provider: | |||
| return { | |||
| 'model_name': default_model.model_name, | |||
| 'model_type': default_model.model_type, | |||
| 'model_provider': { | |||
| 'provider_name': default_model.provider_name | |||
| } | |||
| } | |||
| provider = model_provider.provider | |||
| rst = { | |||
| 'model_name': default_model.model_name, | |||
| 'model_type': default_model.model_type, | |||
| 'model_provider': { | |||
| 'provider_name': provider.provider_name, | |||
| 'provider_type': provider.provider_type | |||
| } | |||
| } | |||
| model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name) | |||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||
| rst['model_provider']['quota_type'] = provider.quota_type | |||
| rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit'] | |||
| rst['model_provider']['quota_limit'] = provider.quota_limit | |||
| rst['model_provider']['quota_used'] = provider.quota_used | |||
| return rst | |||
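A sketch of the JSON this `get` handler returns when the preferred provider is a system (hosted) provider; all values below are illustrative:

```python
# Illustrative response of GET /workspaces/current/default-model?model_type=text-generation
example_response = {
    'model_name': 'gpt-3.5-turbo',
    'model_type': 'text-generation',
    'model_provider': {
        'provider_name': 'openai',
        'provider_type': 'system',   # ProviderType.SYSTEM.value
        'quota_type': 'trial',       # the quota_* fields only appear for system providers
        'quota_unit': 'times',
        'quota_limit': 200,
        'quota_used': 12,
    },
}
```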
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('model_type', type=str, required=True, nullable=False, | |||
| choices=['text-generation', 'embeddings', 'speech2text'], location='json') | |||
| parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| provider_service.update_default_model_of_model_type( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=args['model_type'], | |||
| provider_name=args['provider_name'], | |||
| model_name=args['model_name'] | |||
| ) | |||
| return {'result': 'success'} | |||
| class ValidModelApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, model_type): | |||
| ModelType.value_of(model_type) | |||
| provider_service = ProviderService() | |||
| valid_models = provider_service.get_valid_model_list( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_type=model_type | |||
| ) | |||
| return valid_models | |||
| api.add_resource(DefaultModelApi, '/workspaces/current/default-model') | |||
| api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>') | |||
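A hedged request sketch against the two routes registered above (prefix, port and auth are assumptions, as before):

```python
import requests

BASE = "http://localhost:5001/console/api"          # assumed console prefix
cookies = {"session": "<console-session-cookie>"}   # placeholder auth

# Set the workspace default text-generation model.
requests.post(f"{BASE}/workspaces/current/default-model",
              json={"model_type": "text-generation",
                    "provider_name": "openai",
                    "model_name": "gpt-4"},
              cookies=cookies)

# List models that are valid for a given model type.
print(requests.get(f"{BASE}/workspaces/current/models/model-type/embeddings",
                   cookies=cookies).json())
```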
| @@ -0,0 +1,130 @@ | |||
| # -*- coding:utf-8 -*- | |||
| from flask_login import login_required, current_user | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from core.model_providers.providers.base import CredentialsValidateFailedError | |||
| from models.provider import ProviderType | |||
| from services.provider_service import ProviderService | |||
| class ProviderListApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| tenant_id = current_user.current_tenant_id | |||
| """ | |||
| If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:, | |||
| azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the | |||
| rest is replaced by * and the last two bits are displayed in plaintext | |||
| If the type is other, decode and return the Token field directly, the field displays the first 6 bits in | |||
| plaintext, the rest is replaced by * and the last two bits are displayed in plaintext | |||
| """ | |||
| provider_service = ProviderService() | |||
| provider_info_list = provider_service.get_provider_list(tenant_id) | |||
| provider_list = [ | |||
| { | |||
| 'provider_name': p['provider_name'], | |||
| 'provider_type': p['provider_type'], | |||
| 'is_valid': p['is_valid'], | |||
| 'last_used': p['last_used'], | |||
| 'is_enabled': p['is_valid'], | |||
| **({ | |||
| 'quota_type': p['quota_type'], | |||
| 'quota_limit': p['quota_limit'], | |||
| 'quota_used': p['quota_used'] | |||
| } if p['provider_type'] == ProviderType.SYSTEM.value else {}), | |||
| 'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key']) | |||
| if p['config'] else None | |||
| } | |||
| for name, provider_info in provider_info_list.items() | |||
| for p in provider_info['providers'] | |||
| ] | |||
| return provider_list | |||
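The masking rule described in the docstring above is applied inside `ProviderService`; a minimal sketch of such a helper (hypothetical name, not part of this diff) could be:

```python
def obfuscate_token(token: str) -> str:
    # Keep the first 6 and last 2 characters in plaintext, mask everything in between.
    if len(token) <= 8:
        return token
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

# e.g. obfuscate_token('sk-abcdefghijklmnop') == 'sk-abc' + '*' * 11 + 'op'
```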
| class ProviderTokenApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| # The current user's role in the tenant (ta, i.e. TenantAccountJoin, table) must be admin or owner | |||
| if current_user.current_tenant.current_role not in ['admin', 'owner']: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| if provider == 'openai': | |||
| args['token'] = { | |||
| 'openai_api_key': args['token'] | |||
| } | |||
| provider_service = ProviderService() | |||
| try: | |||
| provider_service.save_custom_provider_config( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider_name=provider, | |||
| config=args['token'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| raise ValueError(str(ex)) | |||
| return {'result': 'success'}, 201 | |||
| class ProviderTokenValidateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self, provider): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('token', required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| provider_service = ProviderService() | |||
| if provider == 'openai': | |||
| args['token'] = { | |||
| 'openai_api_key': args['token'] | |||
| } | |||
| result = True | |||
| error = None | |||
| try: | |||
| provider_service.custom_provider_config_validate( | |||
| provider_name=provider, | |||
| config=args['token'] | |||
| ) | |||
| except CredentialsValidateFailedError as ex: | |||
| result = False | |||
| error = str(ex) | |||
| response = {'result': 'success' if result else 'error'} | |||
| if not result: | |||
| response['error'] = error | |||
| return response | |||
| api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token', | |||
| endpoint='workspaces_current_providers_token') # POST for saving provider token | |||
| api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate', | |||
| endpoint='workspaces_current_providers_token_validate') # POST for validating provider token | |||
| api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list | |||
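And a usage sketch for the token routes (again, prefix and auth are assumptions):

```python
import requests

BASE = "http://localhost:5001/console/api"          # assumed console prefix
cookies = {"session": "<console-session-cookie>"}   # placeholder auth

# Validate an OpenAI key without persisting it.
r = requests.post(f"{BASE}/workspaces/current/providers/openai/token-validate",
                  json={"token": "sk-..."}, cookies=cookies)
print(r.json())  # {'result': 'success'} or {'result': 'error', 'error': '...'}

# Save the key as the workspace's custom provider config (returns 201 on success).
requests.post(f"{BASE}/workspaces/current/providers/openai/token",
              json={"token": "sk-..."}, cookies=cookies)
```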
| @@ -30,7 +30,7 @@ tenant_fields = { | |||
| 'created_at': TimestampField, | |||
| 'role': fields.String, | |||
| 'providers': fields.List(fields.Nested(provider_fields)), | |||
| 'in_trail': fields.Boolean, | |||
| 'in_trial': fields.Boolean, | |||
| 'trial_end_reason': fields.String, | |||
| } | |||
| @@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with | |||
| from controllers.service_api import api | |||
| from controllers.service_api.wraps import AppApiResource | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from models.provider import ProviderName | |||
| from models.model import App | |||
| @@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource): | |||
| def get(self, app_model: App, end_user): | |||
| """Retrieve app parameters.""" | |||
| app_model_config = app_model.app_model_config | |||
| provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1') | |||
| return { | |||
| 'opening_statement': app_model_config.opening_statement, | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list | |||
| } | |||
| @@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn | |||
| ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \ | |||
| ProviderNotSupportSpeechToTextError | |||
| from controllers.service_api.wraps import AppApiResource | |||
| from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from models.model import App, AppModelConfig | |||
| from services.audio_service import AudioService | |||
| @@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotIn | |||
| ProviderModelCurrentlyNotSupportError | |||
| from controllers.service_api.wraps import AppApiResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError | |||
| from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \ | |||
| DatasetNotInitedError | |||
| from controllers.service_api.wraps import DatasetApiResource | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.model_providers.error import ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.model import UploadFile | |||
| @@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields | |||
| from controllers.web import api | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from models.provider import ProviderName | |||
| from models.model import App | |||
| @@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource): | |||
| def get(self, app_model: App, end_user): | |||
| """Retrieve app parameters.""" | |||
| app_model_config = app_model.app_model_config | |||
| provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1') | |||
| return { | |||
| 'opening_statement': app_model_config.opening_statement, | |||
| 'suggested_questions': app_model_config.suggested_questions_list, | |||
| 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False }, | |||
| 'speech_to_text': app_model_config.speech_to_text_dict, | |||
| 'more_like_this': app_model_config.more_like_this_dict, | |||
| 'user_input_form': app_model_config.user_input_form_list | |||
| } | |||
| @@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeErro | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \ | |||
| UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from services.audio_service import AudioService | |||
| from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \ | |||
| @@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedErro | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.conversation_message_task import PubHandler | |||
| from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value | |||
| from services.completion_service import CompletionService | |||
| @@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, Provi | |||
| AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ | |||
| ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError | |||
| from controllers.web.wraps import WebApiResource | |||
| from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ | |||
| ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError | |||
| from libs.helper import uuid_value, TimestampField | |||
| from services.completion_service import CompletionService | |||
| @@ -1,36 +0,0 @@ | |||
| import os | |||
| from typing import Optional | |||
| import langchain | |||
| from flask import Flask | |||
| from pydantic import BaseModel | |||
| from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler | |||
| from core.prompt.prompt_template import OneLineFormatter | |||
| class HostedOpenAICredential(BaseModel): | |||
| api_key: str | |||
| class HostedAnthropicCredential(BaseModel): | |||
| api_key: str | |||
| class HostedLLMCredentials(BaseModel): | |||
| openai: Optional[HostedOpenAICredential] = None | |||
| anthropic: Optional[HostedAnthropicCredential] = None | |||
| hosted_llm_credentials = HostedLLMCredentials() | |||
| def init_app(app: Flask): | |||
| if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': | |||
| langchain.verbose = True | |||
| if app.config.get("OPENAI_API_KEY"): | |||
| hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) | |||
| if app.config.get("ANTHROPIC_API_KEY"): | |||
| hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY")) | |||
| @@ -1,20 +1,17 @@ | |||
| from typing import cast, List | |||
| from typing import List | |||
| from langchain import OpenAI | |||
| from langchain.base_language import BaseLanguageModel | |||
| from langchain.chat_models.openai import ChatOpenAI | |||
| from langchain.schema import BaseMessage | |||
| from core.constant import llm_constant | |||
| from core.model_providers.models.entity.message import to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| class CalcTokenMixin: | |||
| def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int: | |||
| llm = cast(ChatOpenAI, llm) | |||
| return llm.get_num_tokens_from_messages(messages) | |||
| def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||
| return model_instance.get_num_tokens(to_prompt_messages(messages)) | |||
| def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int: | |||
| def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: | |||
| """ | |||
| Get the remaining tokens available to the model after excluding the message tokens and the completion max tokens | |||
| @@ -22,10 +19,9 @@ class CalcTokenMixin: | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| llm = cast(ChatOpenAI, llm) | |||
| llm_max_tokens = llm_constant.max_context_token_length[llm.model_name] | |||
| completion_max_tokens = llm.max_tokens | |||
| used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs) | |||
| llm_max_tokens = model_instance.model_rules.max_tokens.max | |||
| completion_max_tokens = model_instance.model_kwargs.max_tokens | |||
| used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs) | |||
| rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens | |||
| return rest_tokens | |||
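Concretely, the remaining budget is the model's context limit minus the configured completion budget and the tokens already consumed by the messages; a small worked example with assumed numbers:

```python
# Assumed numbers for illustration only.
llm_max_tokens = 4096         # model_instance.model_rules.max_tokens.max
completion_max_tokens = 512   # model_instance.model_kwargs.max_tokens
used_tokens = 3000            # tokens used by the prompt messages

rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
print(rest_tokens)  # 584; a negative value is what triggers summarization downstream
```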
| @@ -4,9 +4,11 @@ from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, BaseLanguageModel, SystemMessage | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from langchain.tools import BaseTool | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| @@ -14,6 +16,12 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): | |||
| """ | |||
| A multi-dataset retrieval agent driven by a router. | |||
| """ | |||
| model_instance: BaseLLM | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| @@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from langchain.tools import BaseTool | |||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError | |||
| @@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio | |||
| # summarize messages if rest_tokens < 0 | |||
| try: | |||
| messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions) | |||
| messages = self.summarize_messages_if_needed(messages, functions=self.functions) | |||
| except ExceededLLMTokensLimitError as e: | |||
| return AgentFinish(return_values={"output": str(e)}, log=str(e)) | |||
| @@ -3,20 +3,28 @@ from typing import cast, List | |||
| from langchain.chat_models import ChatOpenAI | |||
| from langchain.chat_models.openai import _convert_message_to_dict | |||
| from langchain.memory.summary import SummarizerMixin | |||
| from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel | |||
| from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from pydantic import BaseModel | |||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin): | |||
| moving_summary_buffer: str = "" | |||
| moving_summary_index: int = 0 | |||
| summary_llm: BaseLanguageModel | |||
| model_instance: BaseLLM | |||
| def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: | |||
| # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 | |||
| rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs) | |||
| rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) | |||
| rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens | |||
| if rest_tokens >= 0: | |||
| return messages | |||
| @@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFuncti | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts.chat import BaseMessagePromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel | |||
| from langchain.schema import AgentAction, AgentFinish, SystemMessage | |||
| from langchain.schema.language_model import BaseLanguageModel | |||
| from langchain.tools import BaseTool | |||
| from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError | |||
| @@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, Ope | |||
| # summarize messages if rest_tokens < 0 | |||
| try: | |||
| messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions) | |||
| messages = self.summarize_messages_if_needed(messages, functions=self.functions) | |||
| except ExceededLLMTokensLimitError as e: | |||
| return AgentFinish(return_values={"output": str(e)}, log=str(e)) | |||
| @@ -0,0 +1,162 @@ | |||
| import re | |||
| from typing import List, Tuple, Any, Union, Sequence, Optional, cast | |||
| from langchain import BasePromptTemplate | |||
| from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent | |||
| from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE | |||
| from langchain.base_language import BaseLanguageModel | |||
| from langchain.callbacks.base import BaseCallbackManager | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate | |||
| from langchain.schema import AgentAction, AgentFinish, OutputParserException | |||
| from langchain.tools import BaseTool | |||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| Valid "action" values: "Final Answer" or {tool_names} | |||
| Provide only ONE action per $JSON_BLOB, as shown: | |||
| ``` | |||
| {{{{ | |||
| "action": $TOOL_NAME, | |||
| "action_input": $INPUT | |||
| }}}} | |||
| ``` | |||
| Follow this format: | |||
| Question: input question to answer | |||
| Thought: consider previous and subsequent steps | |||
| Action: | |||
| ``` | |||
| $JSON_BLOB | |||
| ``` | |||
| Observation: action result | |||
| ... (repeat Thought/Action/Observation N times) | |||
| Thought: I know what to respond | |||
| Action: | |||
| ``` | |||
| {{{{ | |||
| "action": "Final Answer", | |||
| "action_input": "Final response to human" | |||
| }}}} | |||
| ```""" | |||
| class StructuredMultiDatasetRouterAgent(StructuredChatAgent): | |||
| model_instance: BaseLLM | |||
| dataset_tools: Sequence[BaseTool] | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| Return whether an agent should be used. | |||
| Using the ReACT mode just to decide whether an agent is needed is itself costly, | |||
| so it is cheaper to always use the agent for reasoning directly. | |||
| :param query: | |||
| :return: | |||
| """ | |||
| return True | |||
| def plan( | |||
| self, | |||
| intermediate_steps: List[Tuple[AgentAction, str]], | |||
| callbacks: Callbacks = None, | |||
| **kwargs: Any, | |||
| ) -> Union[AgentAction, AgentFinish]: | |||
| """Given input, decided what to do. | |||
| Args: | |||
| intermediate_steps: Steps the LLM has taken to date, | |||
| along with observations | |||
| callbacks: Callbacks to run. | |||
| **kwargs: User inputs. | |||
| Returns: | |||
| Action specifying what tool to use. | |||
| """ | |||
| if len(self.dataset_tools) == 0: | |||
| return AgentFinish(return_values={"output": ''}, log='') | |||
| elif len(self.dataset_tools) == 1: | |||
| tool = next(iter(self.dataset_tools)) | |||
| tool = cast(DatasetRetrieverTool, tool) | |||
| rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']}) | |||
| return AgentFinish(return_values={"output": rst}, log=rst) | |||
| full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) | |||
| full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) | |||
| try: | |||
| return self.output_parser.parse(full_output) | |||
| except OutputParserException: | |||
| return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " | |||
| "I don't know how to respond to that."}, "") | |||
| @classmethod | |||
| def create_prompt( | |||
| cls, | |||
| tools: Sequence[BaseTool], | |||
| prefix: str = PREFIX, | |||
| suffix: str = SUFFIX, | |||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[List[str]] = None, | |||
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |||
| ) -> BasePromptTemplate: | |||
| tool_strings = [] | |||
| for tool in tools: | |||
| args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) | |||
| tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") | |||
| formatted_tools = "\n".join(tool_strings) | |||
| unique_tool_names = set(tool.name for tool in tools) | |||
| tool_names = ", ".join('"' + name + '"' for name in unique_tool_names) | |||
| format_instructions = format_instructions.format(tool_names=tool_names) | |||
| template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) | |||
| if input_variables is None: | |||
| input_variables = ["input", "agent_scratchpad"] | |||
| _memory_prompts = memory_prompts or [] | |||
| messages = [ | |||
| SystemMessagePromptTemplate.from_template(template), | |||
| *_memory_prompts, | |||
| HumanMessagePromptTemplate.from_template(human_message_template), | |||
| ] | |||
| return ChatPromptTemplate(input_variables=input_variables, messages=messages) | |||
| @classmethod | |||
| def from_llm_and_tools( | |||
| cls, | |||
| llm: BaseLanguageModel, | |||
| tools: Sequence[BaseTool], | |||
| callback_manager: Optional[BaseCallbackManager] = None, | |||
| output_parser: Optional[AgentOutputParser] = None, | |||
| prefix: str = PREFIX, | |||
| suffix: str = SUFFIX, | |||
| human_message_template: str = HUMAN_MESSAGE_TEMPLATE, | |||
| format_instructions: str = FORMAT_INSTRUCTIONS, | |||
| input_variables: Optional[List[str]] = None, | |||
| memory_prompts: Optional[List[BasePromptTemplate]] = None, | |||
| **kwargs: Any, | |||
| ) -> Agent: | |||
| return super().from_llm_and_tools( | |||
| llm=llm, | |||
| tools=tools, | |||
| callback_manager=callback_manager, | |||
| output_parser=output_parser, | |||
| prefix=prefix, | |||
| suffix=suffix, | |||
| human_message_template=human_message_template, | |||
| format_instructions=format_instructions, | |||
| input_variables=input_variables, | |||
| memory_prompts=memory_prompts, | |||
| dataset_tools=tools, | |||
| **kwargs, | |||
| ) | |||
| @@ -14,7 +14,7 @@ from langchain.tools import BaseTool | |||
| from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX | |||
| from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). | |||
| The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. | |||
| @@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| moving_summary_buffer: str = "" | |||
| moving_summary_index: int = 0 | |||
| summary_llm: BaseLanguageModel | |||
| model_instance: BaseLLM | |||
| class Config: | |||
| """Configuration for this pydantic object.""" | |||
| arbitrary_types_allowed = True | |||
| def should_use_agent(self, query: str): | |||
| """ | |||
| @@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): | |||
| if prompts: | |||
| messages = prompts[0].to_messages() | |||
| rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages) | |||
| rest_tokens = self.get_message_rest_tokens(self.model_instance, messages) | |||
| if rest_tokens < 0: | |||
| full_inputs = self.summarize_messages(intermediate_steps, **kwargs) | |||
| @@ -3,7 +3,6 @@ import logging | |||
| from typing import Union, Optional | |||
| from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent | |||
| from langchain.base_language import BaseLanguageModel | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.tools import BaseTool | |||
| @@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent | |||
| from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent | |||
| from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent | |||
| from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser | |||
| from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent | |||
| from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent | |||
| from langchain.agents import AgentExecutor as LCAgentExecutor | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.tool.dataset_retriever_tool import DatasetRetrieverTool | |||
| class PlanningStrategy(str, enum.Enum): | |||
| ROUTER = 'router' | |||
| REACT_ROUTER = 'react_router' | |||
| REACT = 'react' | |||
| FUNCTION_CALL = 'function_call' | |||
| MULTI_FUNCTION_CALL = 'multi_function_call' | |||
| @@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum): | |||
| class AgentConfiguration(BaseModel): | |||
| strategy: PlanningStrategy | |||
| llm: BaseLanguageModel | |||
| model_instance: BaseLLM | |||
| tools: list[BaseTool] | |||
| summary_llm: BaseLanguageModel | |||
| dataset_llm: BaseLanguageModel | |||
| summary_model_instance: BaseLLM | |||
| memory: Optional[BaseChatMemory] = None | |||
| callbacks: Callbacks = None | |||
| max_iterations: int = 6 | |||
| @@ -60,36 +61,49 @@ class AgentExecutor: | |||
| def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]: | |||
| if self.configuration.strategy == PlanningStrategy.REACT: | |||
| agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( | |||
| llm=self.configuration.llm, | |||
| model_instance=self.configuration.model_instance, | |||
| llm=self.configuration.model_instance.client, | |||
| tools=self.configuration.tools, | |||
| output_parser=StructuredChatOutputParser(), | |||
| summary_llm=self.configuration.summary_llm, | |||
| summary_llm=self.configuration.summary_model_instance.client, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: | |||
| agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( | |||
| llm=self.configuration.llm, | |||
| model_instance=self.configuration.model_instance, | |||
| llm=self.configuration.model_instance.client, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used to read chat history from memory | |||
| summary_llm=self.configuration.summary_llm, | |||
| summary_llm=self.configuration.summary_model_instance.client, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: | |||
| agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools( | |||
| llm=self.configuration.llm, | |||
| model_instance=self.configuration.model_instance, | |||
| llm=self.configuration.model_instance.client, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used to read chat history from memory | |||
| summary_llm=self.configuration.summary_llm, | |||
| summary_llm=self.configuration.summary_model_instance.client, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | |||
| agent = MultiDatasetRouterAgent.from_llm_and_tools( | |||
| llm=self.configuration.dataset_llm, | |||
| model_instance=self.configuration.model_instance, | |||
| llm=self.configuration.model_instance.client, | |||
| tools=self.configuration.tools, | |||
| extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, | |||
| verbose=True | |||
| ) | |||
| elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER: | |||
| self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] | |||
| agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools( | |||
| model_instance=self.configuration.model_instance, | |||
| llm=self.configuration.model_instance.client, | |||
| tools=self.configuration.tools, | |||
| output_parser=StructuredChatOutputParser(), | |||
| verbose=True | |||
| ) | |||
| else: | |||
| raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}") | |||
| @@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration | |||
| from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| """Callback Handler that prints to std out.""" | |||
| raise_error: bool = True | |||
| def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None: | |||
| def __init__(self, model_instant: BaseLLM, conversation_message_task: ConversationMessageTask) -> None: | |||
| """Initialize callback handler.""" | |||
| self.model_name = model_name | |||
| self.model_instant = model_instant | |||
| self.conversation_message_task = conversation_message_task | |||
| self._agent_loops = [] | |||
| self._current_loop = None | |||
| @@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at | |||
| self.conversation_message_task.on_agent_end( | |||
| self._message_agent_thought, self.model_name, self._current_loop | |||
| self._message_agent_thought, self.model_instant, self._current_loop | |||
| ) | |||
| self._agent_loops.append(self._current_loop) | |||
| @@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): | |||
| ) | |||
| self.conversation_message_task.on_agent_end( | |||
| self._message_agent_thought, self.model_name, self._current_loop | |||
| self._message_agent_thought, self.model_instant, self._current_loop | |||
| ) | |||
| self._agent_loops.append(self._current_loop) | |||
| @@ -3,18 +3,20 @@ import time | |||
| from typing import Any, Dict, List, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel | |||
| from langchain.schema import LLMResult, BaseMessage | |||
| from core.callback_handler.entity.llm_message import LLMMessage | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | |||
| from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| class LLMCallbackHandler(BaseCallbackHandler): | |||
| raise_error: bool = True | |||
| def __init__(self, llm: BaseLanguageModel, | |||
| def __init__(self, model_instance: BaseLLM, | |||
| conversation_message_task: ConversationMessageTask): | |||
| self.llm = llm | |||
| self.model_instance = model_instance | |||
| self.llm_message = LLMMessage() | |||
| self.start_at = None | |||
| self.conversation_message_task = conversation_message_task | |||
| @@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| }) | |||
| self.llm_message.prompt = real_prompts | |||
| self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0]) | |||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0])) | |||
| def on_llm_start( | |||
| self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | |||
| @@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| "text": prompts[0] | |||
| }] | |||
| self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) | |||
| self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])]) | |||
| def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: | |||
| end_at = time.perf_counter() | |||
| @@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| self.conversation_message_task.append_message_text(response.generations[0][0].text) | |||
| self.llm_message.completion = response.generations[0][0].text | |||
| self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion) | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)]) | |||
| self.conversation_message_task.save_message(self.llm_message) | |||
| @@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler): | |||
| if self.conversation_message_task.streaming: | |||
| end_at = time.perf_counter() | |||
| self.llm_message.latency = end_at - self.start_at | |||
| self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion) | |||
| self.llm_message.completion_tokens = self.model_instance.get_num_tokens( | |||
| [PromptMessage(content=self.llm_message.completion)] | |||
| ) | |||
| self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) | |||
| else: | |||
| logging.error(error) | |||
| @@ -5,9 +5,7 @@ from typing import Any, Dict, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.callback_handler.entity.chain_result import ChainResult | |||
| from core.constant import llm_constant | |||
| from core.conversation_message_task import ConversationMessageTask | |||
| @@ -2,27 +2,19 @@ import logging | |||
| import re | |||
| from typing import Optional, List, Union, Tuple | |||
| from langchain.base_language import BaseLanguageModel | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.chat_models.base import BaseChatModel | |||
| from langchain.llms import BaseLLM | |||
| from langchain.schema import BaseMessage, HumanMessage | |||
| from langchain.schema import BaseMessage | |||
| from requests.exceptions import ChunkedEncodingError | |||
| from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy | |||
| from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler | |||
| from core.constant import llm_constant | |||
| from core.callback_handler.llm_callback_handler import LLMCallbackHandler | |||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ | |||
| DifyStdOutCallbackHandler | |||
| from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException | |||
| from core.llm.error import LLMBadRequestError | |||
| from core.llm.fake import FakeLLM | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | |||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ | |||
| ReadOnlyConversationTokenDBBufferSharedMemory | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.orchestrator_rule_parser import OrchestratorRuleParser | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| @@ -51,12 +43,10 @@ class Completion: | |||
| inputs = conversation.inputs | |||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||
| mode=app.mode, | |||
| final_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||
| tenant_id=app.tenant_id, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs | |||
| model_config=app_model_config.model_dict, | |||
| streaming=streaming | |||
| ) | |||
| conversation_message_task = ConversationMessageTask( | |||
| @@ -68,10 +58,17 @@ class Completion: | |||
| is_override=is_override, | |||
| inputs=inputs, | |||
| query=query, | |||
| streaming=streaming | |||
| streaming=streaming, | |||
| model_instance=final_model_instance | |||
| ) | |||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||
| rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens( | |||
| mode=app.mode, | |||
| model_instance=final_model_instance, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs | |||
| ) | |||
| # init orchestrator rule parser | |||
| orchestrator_rule_parser = OrchestratorRuleParser( | |||
| @@ -80,6 +77,7 @@ class Completion: | |||
| ) | |||
| # parse sensitive_word_avoidance_chain | |||
| chain_callback = MainChainGatherCallbackHandler(conversation_message_task) | |||
| sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback]) | |||
| if sensitive_word_avoidance_chain: | |||
| query = sensitive_word_avoidance_chain.run(query) | |||
| @@ -102,15 +100,14 @@ class Completion: | |||
| # run the final llm | |||
| try: | |||
| cls.run_final_llm( | |||
| tenant_id=app.tenant_id, | |||
| model_instance=final_model_instance, | |||
| mode=app.mode, | |||
| app_model_config=app_model_config, | |||
| query=query, | |||
| inputs=inputs, | |||
| agent_execute_result=agent_execute_result, | |||
| conversation_message_task=conversation_message_task, | |||
| memory=memory, | |||
| streaming=streaming | |||
| memory=memory | |||
| ) | |||
| except ConversationTaskStoppedException: | |||
| return | |||
| @@ -121,31 +118,20 @@ class Completion: | |||
| return | |||
| @classmethod | |||
| def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, | |||
| def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, | |||
| agent_execute_result: Optional[AgentExecuteResult], | |||
| conversation_message_task: ConversationMessageTask, | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool): | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]): | |||
| # When no extra pre-prompt is specified, | |||
| # the agent's output can be used directly as the main output content without calling the LLM again | |||
| fake_response = None | |||
| if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \ | |||
| and agent_execute_result.strategy != PlanningStrategy.ROUTER: | |||
| final_llm = FakeLLM(response=agent_execute_result.output, | |||
| origin_llm=agent_execute_result.configuration.llm, | |||
| streaming=streaming) | |||
| final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) | |||
| response = final_llm.generate([[HumanMessage(content=query)]]) | |||
| return response | |||
| final_llm = LLMBuilder.to_llm_from_model( | |||
| tenant_id=tenant_id, | |||
| model=app_model_config.model_dict, | |||
| streaming=streaming | |||
| ) | |||
| fake_response = agent_execute_result.output | |||
| # get llm prompt | |||
| prompt, stop_words = cls.get_main_llm_prompt( | |||
| prompt_messages, stop_words = cls.get_main_llm_prompt( | |||
| mode=mode, | |||
| llm=final_llm, | |||
| model=app_model_config.model_dict, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| query=query, | |||
| @@ -154,25 +140,26 @@ class Completion: | |||
| memory=memory | |||
| ) | |||
| final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task) | |||
| cls.recale_llm_max_tokens( | |||
| final_llm=final_llm, | |||
| model=app_model_config.model_dict, | |||
| prompt=prompt, | |||
| mode=mode | |||
| model_instance=model_instance, | |||
| prompt_messages=prompt_messages, | |||
| ) | |||
| response = final_llm.generate([prompt], stop_words) | |||
| response = model_instance.run( | |||
| messages=prompt_messages, | |||
| stop=stop_words, | |||
| callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], | |||
| fake_response=fake_response | |||
| ) | |||
| return response | |||
| @classmethod | |||
| def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict, | |||
| def get_main_llm_prompt(cls, mode: str, model: dict, | |||
| pre_prompt: str, query: str, inputs: dict, | |||
| agent_execute_result: Optional[AgentExecuteResult], | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ | |||
| Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]: | |||
| Tuple[List[PromptMessage], Optional[List[str]]]: | |||
| if mode == 'completion': | |||
| prompt_template = JinjaPromptTemplate.from_template( | |||
| template=("""Use the following context as your learned knowledge, inside <context></context> XML tags. | |||
| @@ -200,11 +187,7 @@ And answer according to the language of the user's question. | |||
| **prompt_inputs | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| # use chat llm as completion model | |||
| return [HumanMessage(content=prompt_content)], None | |||
| else: | |||
| return prompt_content, None | |||
| return [PromptMessage(content=prompt_content)], None | |||
| else: | |||
| messages: List[BaseMessage] = [] | |||
| @@ -249,12 +232,14 @@ And answer according to the language of the user's question. | |||
| inputs=human_inputs | |||
| ) | |||
| curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message]) | |||
| model_name = model['name'] | |||
| max_tokens = model.get("completion_params").get('max_tokens') | |||
| rest_tokens = llm_constant.max_context_token_length[model_name] \ | |||
| - max_tokens - curr_message_tokens | |||
| rest_tokens = max(rest_tokens, 0) | |||
| if memory.model_instance.model_rules.max_tokens.max: | |||
| curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message])) | |||
| max_tokens = model.get("completion_params").get('max_tokens') | |||
| rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens | |||
| rest_tokens = max(rest_tokens, 0) | |||
| else: | |||
| rest_tokens = 2000 | |||
| histories = cls.get_history_messages_from_memory(memory, rest_tokens) | |||
| human_message_prompt += "\n\n" if human_message_prompt else "" | |||
| human_message_prompt += "Here is the chat histories between human and assistant, " \ | |||
| @@ -274,17 +259,7 @@ And answer according to the language of the user's question. | |||
| for message in messages: | |||
| message.content = re.sub(r'<\|.*?\|>', '', message.content) | |||
| return messages, ['\nHuman:', '</histories>'] | |||
| @classmethod | |||
| def get_llm_callbacks(cls, llm: BaseLanguageModel, | |||
| streaming: bool, | |||
| conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]: | |||
| llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) | |||
| if streaming: | |||
| return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] | |||
| else: | |||
| return [llm_callback_handler, DifyStdOutCallbackHandler()] | |||
| return to_prompt_messages(messages), ['\nHuman:', '</histories>'] | |||
| @classmethod | |||
| def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, | |||
| @@ -300,15 +275,15 @@ And answer according to the language of the user's question. | |||
| conversation: Conversation, | |||
| **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory: | |||
| # only for calc token in memory | |||
| memory_llm = LLMBuilder.to_llm_from_model( | |||
| memory_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||
| tenant_id=tenant_id, | |||
| model=app_model_config.model_dict | |||
| model_config=app_model_config.model_dict | |||
| ) | |||
| # use llm config from conversation | |||
| memory = ReadOnlyConversationTokenDBBufferSharedMemory( | |||
| conversation=conversation, | |||
| llm=memory_llm, | |||
| model_instance=memory_model_instance, | |||
| max_token_limit=kwargs.get("max_token_limit", 2048), | |||
| memory_key=kwargs.get("memory_key", "chat_history"), | |||
| return_messages=kwargs.get("return_messages", True), | |||
| @@ -320,21 +295,20 @@ And answer according to the language of the user's question. | |||
| return memory | |||
| @classmethod | |||
| def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig, | |||
| def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig, | |||
| query: str, inputs: dict) -> int: | |||
| llm = LLMBuilder.to_llm_from_model( | |||
| tenant_id=tenant_id, | |||
| model=app_model_config.model_dict | |||
| ) | |||
| model_limited_tokens = model_instance.model_rules.max_tokens.max | |||
| max_tokens = model_instance.get_model_kwargs().max_tokens | |||
| model_name = app_model_config.model_dict.get("name") | |||
| model_limited_tokens = llm_constant.max_context_token_length[model_name] | |||
| max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens') | |||
| if model_limited_tokens is None: | |||
| return -1 | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| # get prompt without memory and context | |||
| prompt, _ = cls.get_main_llm_prompt( | |||
| prompt_messages, _ = cls.get_main_llm_prompt( | |||
| mode=mode, | |||
| llm=llm, | |||
| model=app_model_config.model_dict, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| query=query, | |||
| @@ -343,9 +317,7 @@ And answer according to the language of the user's question. | |||
| memory=None | |||
| ) | |||
| prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \ | |||
| else llm.get_num_tokens_from_messages(prompt) | |||
| prompt_tokens = model_instance.get_num_tokens(prompt_messages) | |||
| rest_tokens = model_limited_tokens - max_tokens - prompt_tokens | |||
| if rest_tokens < 0: | |||
| raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, " | |||
| @@ -354,36 +326,40 @@ And answer according to the language of the user's question. | |||
| return rest_tokens | |||
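Under assumed numbers, the validation above would reject an over-long prefix prompt like this:

```python
# Illustrative numbers only.
model_limited_tokens = 4096   # model_instance.model_rules.max_tokens.max
max_tokens = 512              # completion budget from the model kwargs
prompt_tokens = 3900          # prompt without memory and context

rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
print(rest_tokens)  # -316 -> LLMBadRequestError("Query or prefix prompt is too long, ...")
```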
| @classmethod | |||
| def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict, | |||
| prompt: Union[str, List[BaseMessage]], mode: str): | |||
| def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]): | |||
| # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit | |||
| model_name = model.get("name") | |||
| model_limited_tokens = llm_constant.max_context_token_length[model_name] | |||
| max_tokens = model.get("completion_params").get('max_tokens') | |||
| model_limited_tokens = model_instance.model_rules.max_tokens.max | |||
| max_tokens = model_instance.get_model_kwargs().max_tokens | |||
| if mode == 'completion' and isinstance(final_llm, BaseLLM): | |||
| prompt_tokens = final_llm.get_num_tokens(prompt) | |||
| else: | |||
| prompt_tokens = final_llm.get_num_tokens_from_messages(prompt) | |||
| if model_limited_tokens is None: | |||
| return | |||
| if max_tokens is None: | |||
| max_tokens = 0 | |||
| prompt_tokens = model_instance.get_num_tokens(prompt_messages) | |||
| if prompt_tokens + max_tokens > model_limited_tokens: | |||
| max_tokens = max(model_limited_tokens - prompt_tokens, 16) | |||
| final_llm.max_tokens = max_tokens | |||
| # update model instance max tokens | |||
| model_kwargs = model_instance.get_model_kwargs() | |||
| model_kwargs.max_tokens = max_tokens | |||
| model_instance.set_model_kwargs(model_kwargs) | |||
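And the clamping performed by `recale_llm_max_tokens`, with the same assumed numbers:

```python
model_limited_tokens = 4096
max_tokens = 512
prompt_tokens = 3900  # illustrative

if prompt_tokens + max_tokens > model_limited_tokens:           # 4412 > 4096
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)  # max(196, 16) -> 196
# the model kwargs are then written back so the request fits the context window
```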
| @classmethod | |||
| def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, | |||
| app_model_config: AppModelConfig, user: Account, streaming: bool): | |||
| llm = LLMBuilder.to_llm_from_model( | |||
| final_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||
| tenant_id=app.tenant_id, | |||
| model=app_model_config.model_dict, | |||
| model_config=app_model_config.model_dict, | |||
| streaming=streaming | |||
| ) | |||
| # get llm prompt | |||
| original_prompt, _ = cls.get_main_llm_prompt( | |||
| old_prompt_messages, _ = cls.get_main_llm_prompt( | |||
| mode="completion", | |||
| llm=llm, | |||
| model=app_model_config.model_dict, | |||
| pre_prompt=pre_prompt, | |||
| query=message.query, | |||
| @@ -395,10 +371,9 @@ And answer according to the language of the user's question. | |||
| original_completion = message.answer.strip() | |||
| prompt = MORE_LIKE_THIS_GENERATE_PROMPT | |||
| prompt = prompt.format(prompt=original_prompt, original_completion=original_completion) | |||
| prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion) | |||
| if isinstance(llm, BaseChatModel): | |||
| prompt = [HumanMessage(content=prompt)] | |||
| prompt_messages = [PromptMessage(content=prompt)] | |||
| conversation_message_task = ConversationMessageTask( | |||
| task_id=task_id, | |||
| @@ -408,16 +383,16 @@ And answer according to the language of the user's question. | |||
| inputs=message.inputs, | |||
| query=message.query, | |||
| is_override=True if message.override_model_configs else False, | |||
| streaming=streaming | |||
| streaming=streaming, | |||
| model_instance=final_model_instance | |||
| ) | |||
| llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task) | |||
| cls.recale_llm_max_tokens( | |||
| final_llm=llm, | |||
| model=app_model_config.model_dict, | |||
| prompt=prompt, | |||
| mode='completion' | |||
| model_instance=final_model_instance, | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| llm.generate([prompt]) | |||
| final_model_instance.run( | |||
| messages=prompt_messages, | |||
| callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)] | |||
| ) | |||
| @@ -1,109 +0,0 @@ | |||
| from _decimal import Decimal | |||
| models = { | |||
| 'claude-instant-1': 'anthropic', # 100,000 tokens | |||
| 'claude-2': 'anthropic', # 100,000 tokens | |||
| 'gpt-4': 'openai', # 8,192 tokens | |||
| 'gpt-4-32k': 'openai', # 32,768 tokens | |||
| 'gpt-3.5-turbo': 'openai', # 4,096 tokens | |||
| 'gpt-3.5-turbo-16k': 'openai', # 16384 tokens | |||
| 'text-davinci-003': 'openai', # 4,097 tokens | |||
| 'text-davinci-002': 'openai', # 4,097 tokens | |||
| 'text-curie-001': 'openai', # 2,049 tokens | |||
| 'text-babbage-001': 'openai', # 2,049 tokens | |||
| 'text-ada-001': 'openai', # 2,049 tokens | |||
| 'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions | |||
| 'whisper-1': 'openai' | |||
| } | |||
| max_context_token_length = { | |||
| 'claude-instant-1': 100000, | |||
| 'claude-2': 100000, | |||
| 'gpt-4': 8192, | |||
| 'gpt-4-32k': 32768, | |||
| 'gpt-3.5-turbo': 4096, | |||
| 'gpt-3.5-turbo-16k': 16384, | |||
| 'text-davinci-003': 4097, | |||
| 'text-davinci-002': 4097, | |||
| 'text-curie-001': 2049, | |||
| 'text-babbage-001': 2049, | |||
| 'text-ada-001': 2049, | |||
| 'text-embedding-ada-002': 8191, | |||
| } | |||
| models_by_mode = { | |||
| 'chat': [ | |||
| 'claude-instant-1', # 100,000 tokens | |||
| 'claude-2', # 100,000 tokens | |||
| 'gpt-4', # 8,192 tokens | |||
| 'gpt-4-32k', # 32,768 tokens | |||
| 'gpt-3.5-turbo', # 4,096 tokens | |||
| 'gpt-3.5-turbo-16k', # 16,384 tokens | |||
| ], | |||
| 'completion': [ | |||
| 'claude-instant-1', # 100,000 tokens | |||
| 'claude-2', # 100,000 tokens | |||
| 'gpt-4', # 8,192 tokens | |||
| 'gpt-4-32k', # 32,768 tokens | |||
| 'gpt-3.5-turbo', # 4,096 tokens | |||
| 'gpt-3.5-turbo-16k', # 16,384 tokens | |||
| 'text-davinci-003', # 4,097 tokens | |||
| 'text-davinci-002' # 4,097 tokens | |||
| 'text-curie-001', # 2,049 tokens | |||
| 'text-babbage-001', # 2,049 tokens | |||
| 'text-ada-001' # 2,049 tokens | |||
| ], | |||
| 'embedding': [ | |||
| 'text-embedding-ada-002' # 8191 tokens, 1536 dimensions | |||
| ] | |||
| } | |||
| model_currency = 'USD' | |||
| model_prices = { | |||
| 'claude-instant-1': { | |||
| 'prompt': Decimal('0.00163'), | |||
| 'completion': Decimal('0.00551'), | |||
| }, | |||
| 'claude-2': { | |||
| 'prompt': Decimal('0.01102'), | |||
| 'completion': Decimal('0.03268'), | |||
| }, | |||
| 'gpt-4': { | |||
| 'prompt': Decimal('0.03'), | |||
| 'completion': Decimal('0.06'), | |||
| }, | |||
| 'gpt-4-32k': { | |||
| 'prompt': Decimal('0.06'), | |||
| 'completion': Decimal('0.12') | |||
| }, | |||
| 'gpt-3.5-turbo': { | |||
| 'prompt': Decimal('0.0015'), | |||
| 'completion': Decimal('0.002') | |||
| }, | |||
| 'gpt-3.5-turbo-16k': { | |||
| 'prompt': Decimal('0.003'), | |||
| 'completion': Decimal('0.004') | |||
| }, | |||
| 'text-davinci-003': { | |||
| 'prompt': Decimal('0.02'), | |||
| 'completion': Decimal('0.02') | |||
| }, | |||
| 'text-curie-001': { | |||
| 'prompt': Decimal('0.002'), | |||
| 'completion': Decimal('0.002') | |||
| }, | |||
| 'text-babbage-001': { | |||
| 'prompt': Decimal('0.0005'), | |||
| 'completion': Decimal('0.0005') | |||
| }, | |||
| 'text-ada-001': { | |||
| 'prompt': Decimal('0.0004'), | |||
| 'completion': Decimal('0.0004') | |||
| }, | |||
| 'text-embedding-ada-002': { | |||
| 'usage': Decimal('0.0001'), | |||
| } | |||
| } | |||
| agent_model_name = 'text-davinci-003' | |||
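The static tables above are deleted outright: token limits, prices, and currency are now answered by the model instance itself, so supporting a new provider no longer means editing a constants file. A before/after sketch of the lookups this PR migrates (the "after" calls are the ones that appear in the hunks below):

    # Before: static dictionaries keyed by model name.
    price = llm_constant.model_prices['gpt-3.5-turbo']['prompt']
    limit = llm_constant.max_context_token_length['gpt-3.5-turbo']
    currency = llm_constant.model_currency

    # After: ask the model instance obtained from ModelFactory.
    price = model_instance.get_token_price(1, MessageType.HUMAN)
    limit = model_instance.model_rules.max_tokens.max
    currency = model_instance.get_currency()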
| @@ -6,9 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop | |||
| from core.callback_handler.entity.dataset_query import DatasetQueryObj | |||
| from core.callback_handler.entity.llm_message import LLMMessage | |||
| from core.callback_handler.entity.chain_result import ChainResult | |||
| from core.constant import llm_constant | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.llm.provider.llm_provider_service import LLMProviderService | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import to_prompt_messages, MessageType | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| from events.message_event import message_was_created | |||
| @@ -16,12 +16,11 @@ from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DatasetQuery | |||
| from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain | |||
| from models.provider import ProviderType, Provider | |||
| class ConversationMessageTask: | |||
| def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, | |||
| inputs: dict, query: str, streaming: bool, | |||
| inputs: dict, query: str, streaming: bool, model_instance: BaseLLM, | |||
| conversation: Optional[Conversation] = None, is_override: bool = False): | |||
| self.task_id = task_id | |||
| @@ -38,9 +37,12 @@ class ConversationMessageTask: | |||
| self.conversation = conversation | |||
| self.is_new_conversation = False | |||
| self.model_instance = model_instance | |||
| self.message = None | |||
| self.model_dict = self.app_model_config.model_dict | |||
| self.provider_name = self.model_dict.get('provider') | |||
| self.model_name = self.model_dict.get('name') | |||
| self.mode = app.mode | |||
| @@ -56,9 +58,6 @@ class ConversationMessageTask: | |||
| ) | |||
| def init(self): | |||
| provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name) | |||
| self.model_dict['provider'] = provider_name | |||
| override_model_configs = None | |||
| if self.is_override: | |||
| override_model_configs = { | |||
| @@ -89,15 +88,19 @@ class ConversationMessageTask: | |||
| if self.app_model_config.pre_prompt: | |||
| system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) | |||
| system_instruction = system_message.content | |||
| llm = LLMBuilder.to_llm(self.tenant_id, self.model_name) | |||
| system_instruction_tokens = llm.get_num_tokens_from_messages([system_message]) | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=self.tenant_id, | |||
| model_provider_name=self.provider_name, | |||
| model_name=self.model_name | |||
| ) | |||
| system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message])) | |||
| if not self.conversation: | |||
| self.is_new_conversation = True | |||
| self.conversation = Conversation( | |||
| app_id=self.app_model_config.app_id, | |||
| app_model_config_id=self.app_model_config.id, | |||
| model_provider=self.model_dict.get('provider'), | |||
| model_provider=self.provider_name, | |||
| model_id=self.model_name, | |||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | |||
| mode=self.mode, | |||
| @@ -117,7 +120,7 @@ class ConversationMessageTask: | |||
| self.message = Message( | |||
| app_id=self.app_model_config.app_id, | |||
| model_provider=self.model_dict.get('provider'), | |||
| model_provider=self.provider_name, | |||
| model_id=self.model_name, | |||
| override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, | |||
| conversation_id=self.conversation.id, | |||
| @@ -131,7 +134,7 @@ class ConversationMessageTask: | |||
| answer_unit_price=0, | |||
| provider_response_latency=0, | |||
| total_price=0, | |||
| currency=llm_constant.model_currency, | |||
| currency=self.model_instance.get_currency(), | |||
| from_source=('console' if isinstance(self.user, Account) else 'api'), | |||
| from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), | |||
| from_account_id=(self.user.id if isinstance(self.user, Account) else None), | |||
| @@ -145,12 +148,10 @@ class ConversationMessageTask: | |||
| self._pub_handler.pub_text(text) | |||
| def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): | |||
| model_name = self.app_model_config.model_dict.get('name') | |||
| message_tokens = llm_message.prompt_tokens | |||
| answer_tokens = llm_message.completion_tokens | |||
| message_unit_price = llm_constant.model_prices[model_name]['prompt'] | |||
| answer_unit_price = llm_constant.model_prices[model_name]['completion'] | |||
| message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN) | |||
| answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT) | |||
| total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) | |||
| @@ -163,8 +164,6 @@ class ConversationMessageTask: | |||
| self.message.provider_response_latency = llm_message.latency | |||
| self.message.total_price = total_price | |||
| self.update_provider_quota() | |||
| db.session.commit() | |||
| message_was_created.send( | |||
| @@ -176,20 +175,6 @@ class ConversationMessageTask: | |||
| if not by_stopped: | |||
| self.end() | |||
| def update_provider_quota(self): | |||
| llm_provider_service = LLMProviderService( | |||
| tenant_id=self.app.tenant_id, | |||
| provider_name=self.message.model_provider, | |||
| ) | |||
| provider = llm_provider_service.get_provider_db_record() | |||
| if provider and provider.provider_type == ProviderType.SYSTEM.value: | |||
| db.session.query(Provider).filter( | |||
| Provider.tenant_id == self.app.tenant_id, | |||
| Provider.provider_name == provider.provider_name, | |||
| Provider.quota_limit > Provider.quota_used | |||
| ).update({'quota_used': Provider.quota_used + 1}) | |||
| def init_chain(self, chain_result: ChainResult): | |||
| message_chain = MessageChain( | |||
| message_id=self.message.id, | |||
| @@ -229,10 +214,10 @@ class ConversationMessageTask: | |||
| return message_agent_thought | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str, | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, | |||
| agent_loop: AgentLoop): | |||
| agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt'] | |||
| agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion'] | |||
| agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN) | |||
| agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT) | |||
| loop_message_tokens = agent_loop.prompt_tokens | |||
| loop_answer_tokens = agent_loop.completion_tokens | |||
| @@ -253,7 +238,7 @@ class ConversationMessageTask: | |||
| message_agent_thought.latency = agent_loop.latency | |||
| message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens | |||
| message_agent_thought.total_price = loop_total_price | |||
| message_agent_thought.currency = llm_constant.model_currency | |||
| message_agent_thought.currency = agent_model_instant.get_currency() | |||
| db.session.flush() | |||
| def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): | |||
| @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence | |||
| from langchain.schema import Document | |||
| from sqlalchemy import func | |||
| from core.llm.token_calculator import TokenCalculator | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| @@ -13,12 +13,10 @@ class DatesetDocumentStore: | |||
| self, | |||
| dataset: Dataset, | |||
| user_id: str, | |||
| embedding_model_name: str, | |||
| document_id: Optional[str] = None, | |||
| ): | |||
| self._dataset = dataset | |||
| self._user_id = user_id | |||
| self._embedding_model_name = embedding_model_name | |||
| self._document_id = document_id | |||
| @classmethod | |||
| @@ -39,10 +37,6 @@ class DatesetDocumentStore: | |||
| def user_id(self) -> Any: | |||
| return self._user_id | |||
| @property | |||
| def embedding_model_name(self) -> Any: | |||
| return self._embedding_model_name | |||
| @property | |||
| def docs(self) -> Dict[str, Document]: | |||
| document_segments = db.session.query(DocumentSegment).filter( | |||
| @@ -74,6 +68,10 @@ class DatesetDocumentStore: | |||
| if max_position is None: | |||
| max_position = 0 | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=self._dataset.tenant_id | |||
| ) | |||
| for doc in docs: | |||
| if not isinstance(doc, Document): | |||
| raise ValueError("doc must be a Document") | |||
| @@ -88,7 +86,7 @@ class DatesetDocumentStore: | |||
| ) | |||
| # calc embedding use tokens | |||
| tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) | |||
| tokens = embedding_model.get_num_tokens(doc.page_content) | |||
| if not segment_document: | |||
| max_position += 1 | |||
| @@ -4,14 +4,14 @@ from typing import List | |||
| from langchain.embeddings.base import Embeddings | |||
| from sqlalchemy.exc import IntegrityError | |||
| from core.llm.wrappers.openai_wrapper import handle_openai_exceptions | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models.dataset import Embedding | |||
| class CacheEmbedding(Embeddings): | |||
| def __init__(self, embeddings: Embeddings): | |||
| def __init__(self, embeddings: BaseEmbedding): | |||
| self._embeddings = embeddings | |||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||
| @@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings): | |||
| embedding_queue_texts = [] | |||
| for text in texts: | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||
| if embedding: | |||
| text_embeddings.append(embedding.get_embedding()) | |||
| else: | |||
| embedding_queue_texts.append(text) | |||
| embedding_results = self._embeddings.embed_documents(embedding_queue_texts) | |||
| if embedding_queue_texts: | |||
| try: | |||
| embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| i = 0 | |||
| for text in embedding_queue_texts: | |||
| hash = helper.generate_text_hash(text) | |||
| i = 0 | |||
| for text in embedding_queue_texts: | |||
| hash = helper.generate_text_hash(text) | |||
| try: | |||
| embedding = Embedding(hash=hash) | |||
| embedding.set_embedding(embedding_results[i]) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| continue | |||
| except: | |||
| logging.exception('Failed to add embedding to db') | |||
| continue | |||
| finally: | |||
| i += 1 | |||
| try: | |||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||
| embedding.set_embedding(embedding_results[i]) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| continue | |||
| except: | |||
| logging.exception('Failed to add embedding to db') | |||
| continue | |||
| finally: | |||
| i += 1 | |||
| text_embeddings.extend(embedding_results) | |||
| text_embeddings.extend(embedding_results) | |||
| return text_embeddings | |||
| @handle_openai_exceptions | |||
| def embed_query(self, text: str) -> List[float]: | |||
| """Embed query text.""" | |||
| # use doc embedding cache or store if not exists | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(hash=hash).first() | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||
| if embedding: | |||
| return embedding.get_embedding() | |||
| embedding_results = self._embeddings.embed_query(text) | |||
| try: | |||
| embedding_results = self._embeddings.client.embed_query(text) | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| try: | |||
| embedding = Embedding(hash=hash) | |||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||
| embedding.set_embedding(embedding_results) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| @@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings): | |||
| logging.exception('Failed to add embedding to db') | |||
| return embedding_results | |||
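Because the cache key now includes the embedding model's name, vectors cached for one model are never served for another, and provider errors are translated by the model's own handle_exceptions instead of the removed OpenAI-specific wrapper. A short usage sketch, mirroring how IndexBuilder wires this up later in the diff (tenant_id is a placeholder):

    # Hedged sketch; tenant_id is a placeholder value.
    embedding_model = ModelFactory.get_embedding_model(tenant_id=tenant_id)
    embeddings = CacheEmbedding(embedding_model)

    vectors = embeddings.embed_documents(["first chunk", "second chunk"])  # cached per (model_name, hash)
    query_vector = embeddings.embed_query("query text")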
| @@ -1,13 +1,10 @@ | |||
| import logging | |||
| from langchain import PromptTemplate | |||
| from langchain.chat_models.base import BaseChatModel | |||
| from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage | |||
| from core.constant import llm_constant | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||
| from core.llm.token_calculator import TokenCalculator | |||
| from langchain.schema import OutputParserException | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelKwargs | |||
| from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser | |||
| from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser | |||
| @@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla | |||
| from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ | |||
| GENERATOR_QA_PROMPT | |||
| # gpt-3.5-turbo works not well | |||
| generate_base_model = 'text-davinci-003' | |||
| class LLMGenerator: | |||
| @classmethod | |||
| @@ -28,29 +22,35 @@ class LLMGenerator: | |||
| query = query[:300] + "...[TRUNCATED]..." + query[-300:] | |||
| prompt = prompt.format(query=query) | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_name='gpt-3.5-turbo', | |||
| max_tokens=50, | |||
| timeout=600 | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=50 | |||
| ) | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| prompt = [HumanMessage(content=prompt)] | |||
| response = llm.generate([prompt]) | |||
| answer = response.generations[0][0].text | |||
| prompts = [PromptMessage(content=prompt)] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| return answer.strip() | |||
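Every generator in this file now follows the same three steps: build a model instance with ModelKwargs, wrap the prompt in PromptMessage, and read .content from run(). Sketched once here so the later hunks are easier to scan (placeholder tenant id and prompt text):

    model_instance = ModelFactory.get_text_generation_model(
        tenant_id=tenant_id,
        model_kwargs=ModelKwargs(max_tokens=50)
    )
    response = model_instance.run([PromptMessage(content="...prompt...")])
    answer = response.content.strip()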
| @classmethod | |||
| def generate_conversation_summary(cls, tenant_id: str, messages): | |||
| max_tokens = 200 | |||
| model = 'gpt-3.5-turbo' | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=max_tokens | |||
| ) | |||
| ) | |||
| prompt = CONVERSATION_SUMMARY_PROMPT | |||
| prompt_with_empty_context = prompt.format(context='') | |||
| prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context) | |||
| rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1 | |||
| prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)]) | |||
| max_context_token_length = model_instance.model_rules.max_tokens.max | |||
| rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1 | |||
| context = '' | |||
| for message in messages: | |||
| @@ -68,25 +68,16 @@ class LLMGenerator: | |||
| answer = message.answer | |||
| message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer | |||
| if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0: | |||
| if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0: | |||
| context += message_qa_text | |||
| if not context: | |||
| return '[message too long, no summary]' | |||
| prompt = prompt.format(context=context) | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name=model, | |||
| max_tokens=max_tokens | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| prompt = [HumanMessage(content=prompt)] | |||
| response = llm.generate([prompt]) | |||
| answer = response.generations[0][0].text | |||
| prompts = [PromptMessage(content=prompt)] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| return answer.strip() | |||
| @classmethod | |||
| @@ -94,16 +85,13 @@ class LLMGenerator: | |||
| prompt = INTRODUCTION_GENERATE_PROMPT | |||
| prompt = prompt.format(prompt=pre_prompt) | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name=generate_base_model, | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| prompt = [HumanMessage(content=prompt)] | |||
| response = llm.generate([prompt]) | |||
| answer = response.generations[0][0].text | |||
| prompts = [PromptMessage(content=prompt)] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| return answer.strip() | |||
| @classmethod | |||
| @@ -119,23 +107,19 @@ class LLMGenerator: | |||
| _input = prompt.format_prompt(histories=histories) | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_name='gpt-3.5-turbo', | |||
| temperature=0, | |||
| max_tokens=256 | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=256, | |||
| temperature=0 | |||
| ) | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| query = [HumanMessage(content=_input.to_string())] | |||
| else: | |||
| query = _input.to_string() | |||
| prompts = [PromptMessage(content=_input.to_string())] | |||
| try: | |||
| output = llm(query) | |||
| if isinstance(output, BaseMessage): | |||
| output = output.content | |||
| questions = output_parser.parse(output) | |||
| output = model_instance.run(prompts) | |||
| questions = output_parser.parse(output.content) | |||
| except Exception: | |||
| logging.exception("Error generating suggested questions after answer") | |||
| questions = [] | |||
| @@ -160,21 +144,19 @@ class LLMGenerator: | |||
| _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_name=generate_base_model, | |||
| temperature=0, | |||
| max_tokens=512 | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=512, | |||
| temperature=0 | |||
| ) | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| query = [HumanMessage(content=_input.to_string())] | |||
| else: | |||
| query = _input.to_string() | |||
| prompts = [PromptMessage(content=_input.to_string())] | |||
| try: | |||
| output = llm(query) | |||
| rule_config = output_parser.parse(output) | |||
| output = model_instance.run(prompts) | |||
| rule_config = output_parser.parse(output.content) | |||
| except OutputParserException: | |||
| raise ValueError('Please give a valid input for intended audience or hoping to solve problems.') | |||
| except Exception: | |||
| @@ -188,25 +170,21 @@ class LLMGenerator: | |||
| return rule_config | |||
| @classmethod | |||
| async def generate_qa_document(cls, llm: StreamableOpenAI, query): | |||
| def generate_qa_document(cls, tenant_id: str, query): | |||
| prompt = GENERATOR_QA_PROMPT | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=2000 | |||
| ) | |||
| ) | |||
| if isinstance(llm, BaseChatModel): | |||
| prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] | |||
| response = llm.generate([prompt]) | |||
| answer = response.generations[0][0].text | |||
| return answer.strip() | |||
| @classmethod | |||
| def generate_qa_document_sync(cls, llm: StreamableOpenAI, query): | |||
| prompt = GENERATOR_QA_PROMPT | |||
| if isinstance(llm, BaseChatModel): | |||
| prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] | |||
| prompts = [ | |||
| PromptMessage(content=prompt, type=MessageType.SYSTEM), | |||
| PromptMessage(content=query) | |||
| ] | |||
| response = llm.generate([prompt]) | |||
| answer = response.generations[0][0].text | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| return answer.strip() | |||
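The old async/sync pair collapses into this single synchronous generate_qa_document that takes a tenant id instead of a prebuilt LLM. The IndexingRunner hunks further down call it directly, for example:

    # Inside IndexingRunner.format_qa_document (see the hunk later in this diff):
    response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content)
    document_qa_list = self.format_split_text(response)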
| @@ -0,0 +1,20 @@ | |||
| import base64 | |||
| from extensions.ext_database import db | |||
| from libs import rsa | |||
| from models.account import Tenant | |||
| def obfuscated_token(token: str): | |||
| return token[:6] + '*' * (len(token) - 8) + token[-2:] | |||
| def encrypt_token(tenant_id: str, token: str): | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() | |||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||
| return base64.b64encode(encrypted_token).decode() | |||
| def decrypt_token(tenant_id: str, token: str): | |||
| return rsa.decrypt(base64.b64decode(token), tenant_id) | |||
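This new helper module gives the token handling a home outside BaseProvider (which is deleted at the end of this diff): tokens are RSA-encrypted with the tenant's public key, base64-encoded for storage, and masked for display. A hedged usage example with placeholder values; encryption and decryption only work inside the app because they read the tenant's keys from the database:

    tenant_id = "tenant-id-placeholder"
    ciphertext = encrypt_token(tenant_id, 'sk-abc123def456')
    plaintext = decrypt_token(tenant_id, ciphertext)
    masked = obfuscated_token('sk-abc123def456')  # keeps the first six and last two characters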
| @@ -1,10 +1,9 @@ | |||
| from flask import current_app | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from core.embedding.cached_embedding import CacheEmbedding | |||
| from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig | |||
| from core.index.vector_index.vector_index import VectorIndex | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from models.dataset import Dataset | |||
| @@ -15,16 +14,11 @@ class IndexBuilder: | |||
| if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': | |||
| return None | |||
| model_credentials = LLMBuilder.get_model_credentials( | |||
| tenant_id=dataset.tenant_id, | |||
| model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'), | |||
| model_name='text-embedding-ada-002' | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id | |||
| ) | |||
| embeddings = CacheEmbedding(OpenAIEmbeddings( | |||
| max_retries=1, | |||
| **model_credentials | |||
| )) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| return VectorIndex( | |||
| dataset=dataset, | |||
| @@ -1,4 +1,3 @@ | |||
| import concurrent | |||
| import datetime | |||
| import json | |||
| import logging | |||
| @@ -6,7 +5,6 @@ import re | |||
| import threading | |||
| import time | |||
| import uuid | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Optional, List, cast | |||
| from flask_login import current_user | |||
| @@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader | |||
| from core.docstore.dataset_docstore import DatesetDocumentStore | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.index.index import IndexBuilder | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.llm.llm_builder import LLMBuilder | |||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||
| from core.model_providers.error import ProviderTokenNotInitError | |||
| from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import MessageType | |||
| from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter | |||
| from core.llm.token_calculator import TokenCalculator | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from extensions.ext_storage import storage | |||
| @@ -35,9 +32,8 @@ from models.source import DataSourceBinding | |||
| class IndexingRunner: | |||
| def __init__(self, embedding_model_name: str = "text-embedding-ada-002"): | |||
| def __init__(self): | |||
| self.storage = storage | |||
| self.embedding_model_name = embedding_model_name | |||
| def run(self, dataset_documents: List[DatasetDocument]): | |||
| """Run the indexing process.""" | |||
| @@ -227,11 +223,15 @@ class IndexingRunner: | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, | |||
| def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict, | |||
| doc_form: str = None) -> dict: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=tenant_id | |||
| ) | |||
| tokens = 0 | |||
| preview_texts = [] | |||
| total_segments = 0 | |||
| @@ -253,44 +253,49 @@ class IndexingRunner: | |||
| splitter=splitter, | |||
| processing_rule=processing_rule | |||
| ) | |||
| total_segments += len(documents) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, | |||
| self.filter_string(document.page_content)) | |||
| tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content)) | |||
| text_generation_model = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id | |||
| ) | |||
| if doc_form and doc_form == 'qa_model': | |||
| if len(preview_texts) > 0: | |||
| # qa model document | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_name='gpt-3.5-turbo', | |||
| max_tokens=2000 | |||
| ) | |||
| response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) | |||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) | |||
| document_qa_list = self.format_split_text(response) | |||
| return { | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), | |||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||
| text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), | |||
| "currency": embedding_model.get_currency(), | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| } | |||
| return { | |||
| "total_segments": total_segments, | |||
| "tokens": tokens, | |||
| "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), | |||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||
| "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), | |||
| "currency": embedding_model.get_currency(), | |||
| "preview": preview_texts | |||
| } | |||
| def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: | |||
| def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=tenant_id | |||
| ) | |||
| # load data from notion | |||
| tokens = 0 | |||
| preview_texts = [] | |||
| @@ -336,31 +341,31 @@ class IndexingRunner: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||
| tokens += embedding_model.get_num_tokens(document.page_content) | |||
| text_generation_model = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id | |||
| ) | |||
| if doc_form and doc_form == 'qa_model': | |||
| if len(preview_texts) > 0: | |||
| # qa model document | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| tenant_id=current_user.current_tenant_id, | |||
| model_name='gpt-3.5-turbo', | |||
| max_tokens=2000 | |||
| ) | |||
| response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) | |||
| response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) | |||
| document_qa_list = self.format_split_text(response) | |||
| return { | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), | |||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||
| text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), | |||
| "currency": embedding_model.get_currency(), | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| } | |||
| return { | |||
| "total_segments": total_segments, | |||
| "tokens": tokens, | |||
| "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), | |||
| "currency": TokenCalculator.get_currency(self.embedding_model_name), | |||
| "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), | |||
| "currency": embedding_model.get_currency(), | |||
| "preview": preview_texts | |||
| } | |||
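Both estimate methods share the same qa_model heuristic seen above: each source segment is assumed to fan out to roughly 20 QA segments and 2,000 generation tokens, priced at the text generation model's prompt rate. For example, with illustrative numbers only:

    # 50 source segments under the qa_model heuristic:
    total_segments = 50
    estimated_segments = total_segments * 20    # 1,000 estimated QA segments
    estimated_tokens = total_segments * 2000    # 100,000 estimated tokens
    # total_price = text_generation_model.get_token_price(estimated_tokens, MessageType.HUMAN)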
| @@ -459,7 +464,6 @@ class IndexingRunner: | |||
| doc_store = DatesetDocumentStore( | |||
| dataset=dataset, | |||
| user_id=dataset_document.created_by, | |||
| embedding_model_name=self.embedding_model_name, | |||
| document_id=dataset_document.id | |||
| ) | |||
| @@ -513,17 +517,12 @@ class IndexingRunner: | |||
| all_documents.extend(split_documents) | |||
| # processing qa document | |||
| if document_form == 'qa_model': | |||
| llm: StreamableOpenAI = LLMBuilder.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name='gpt-3.5-turbo', | |||
| max_tokens=2000 | |||
| ) | |||
| for i in range(0, len(all_documents), 10): | |||
| threads = [] | |||
| sub_documents = all_documents[i:i + 10] | |||
| for doc in sub_documents: | |||
| document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ | |||
| 'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents}) | |||
| 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents}) | |||
| threads.append(document_format_thread) | |||
| document_format_thread.start() | |||
| for thread in threads: | |||
| @@ -531,13 +530,13 @@ class IndexingRunner: | |||
| return all_qa_documents | |||
| return all_documents | |||
| def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents): | |||
| def format_qa_document(self, tenant_id: str, document_node, all_qa_documents): | |||
| format_documents = [] | |||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||
| return | |||
| try: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) | |||
| response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content) | |||
| document_qa_list = self.format_split_text(response) | |||
| qa_documents = [] | |||
| for result in document_qa_list: | |||
| @@ -638,6 +637,10 @@ class IndexingRunner: | |||
| vector_index = IndexBuilder.get_index(dataset, 'high_quality') | |||
| keyword_table_index = IndexBuilder.get_index(dataset, 'economy') | |||
| embedding_model = ModelFactory.get_embedding_model( | |||
| tenant_id=dataset.tenant_id | |||
| ) | |||
| # chunk nodes by chunk size | |||
| indexing_start_at = time.perf_counter() | |||
| tokens = 0 | |||
| @@ -648,7 +651,7 @@ class IndexingRunner: | |||
| chunk_documents = documents[i:i + chunk_size] | |||
| tokens += sum( | |||
| TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) | |||
| embedding_model.get_num_tokens(document.page_content) | |||
| for document in chunk_documents | |||
| ) | |||
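Everything from here on is deletion: llm_builder.py, the moderation wrapper, and the provider classes (Anthropic, Azure, base, errors) are removed wholesale, superseded by the core.model_providers package used throughout the hunks above. A rough mapping, hedged because the new package internals are not shown in this excerpt:

    # Old helper (deleted below)                 -> replacement used in this PR
    # LLMBuilder.to_llm(tenant_id, model_name)   -> ModelFactory.get_text_generation_model(tenant_id=...)
    # LLMBuilder.get_model_credentials(...)      -> handled inside the core.model_providers providers
    # TokenCalculator.get_num_tokens(...)        -> model_instance.get_num_tokens(...)
    # llm_constant.model_prices / model_currency -> model_instance.get_token_price() / get_currency()
    # Moderation(provider, api_key)              -> removed here; its replacement is not shown in this excerpt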
| @@ -1,148 +0,0 @@ | |||
| from typing import Union, Optional, List | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from core.constant import llm_constant | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.llm.provider.base import BaseProvider | |||
| from core.llm.provider.llm_provider_service import LLMProviderService | |||
| from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI | |||
| from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI | |||
| from core.llm.streamable_chat_anthropic import StreamableChatAnthropic | |||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | |||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||
| from models.provider import ProviderType, ProviderName | |||
| class LLMBuilder: | |||
| """ | |||
| This class handles the following logic: | |||
| 1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config. | |||
| 2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below: | |||
| OPENAI_API_TYPE=azure | |||
| OPENAI_API_VERSION=2022-12-01 | |||
| OPENAI_API_BASE=https://your-resource-name.openai.azure.com | |||
| OPENAI_API_KEY=<your Azure OpenAI API key> | |||
| 3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config. | |||
| 4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config. | |||
| 5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config. | |||
| 6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface. | |||
| 7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter. | |||
| 8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting. | |||
| """ | |||
| @classmethod | |||
| def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||
| provider = cls.get_default_provider(tenant_id, model_name) | |||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||
| llm_cls = None | |||
| mode = cls.get_mode_by_model(model_name) | |||
| if mode == 'chat': | |||
| if provider == ProviderName.OPENAI.value: | |||
| llm_cls = StreamableChatOpenAI | |||
| elif provider == ProviderName.AZURE_OPENAI.value: | |||
| llm_cls = StreamableAzureChatOpenAI | |||
| elif provider == ProviderName.ANTHROPIC.value: | |||
| llm_cls = StreamableChatAnthropic | |||
| elif mode == 'completion': | |||
| if provider == ProviderName.OPENAI.value: | |||
| llm_cls = StreamableOpenAI | |||
| elif provider == ProviderName.AZURE_OPENAI.value: | |||
| llm_cls = StreamableAzureOpenAI | |||
| if not llm_cls: | |||
| raise ValueError(f"model name {model_name} is not supported.") | |||
| model_kwargs = { | |||
| 'model_name': model_name, | |||
| 'temperature': kwargs.get('temperature', 0), | |||
| 'max_tokens': kwargs.get('max_tokens', 256), | |||
| 'top_p': kwargs.get('top_p', 1), | |||
| 'frequency_penalty': kwargs.get('frequency_penalty', 0), | |||
| 'presence_penalty': kwargs.get('presence_penalty', 0), | |||
| 'callbacks': kwargs.get('callbacks', None), | |||
| 'streaming': kwargs.get('streaming', False), | |||
| } | |||
| model_kwargs.update(model_credentials) | |||
| model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs) | |||
| return llm_cls(**model_kwargs) | |||
| @classmethod | |||
| def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, | |||
| callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: | |||
| model_name = model.get("name") | |||
| completion_params = model.get("completion_params", {}) | |||
| return cls.to_llm( | |||
| tenant_id=tenant_id, | |||
| model_name=model_name, | |||
| temperature=completion_params.get('temperature', 0), | |||
| max_tokens=completion_params.get('max_tokens', 256), | |||
| top_p=completion_params.get('top_p', 0), | |||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | |||
| presence_penalty=completion_params.get('presence_penalty', 0.1), | |||
| streaming=streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| @classmethod | |||
| def get_mode_by_model(cls, model_name: str) -> str: | |||
| if not model_name: | |||
| raise ValueError(f"empty model name is not supported.") | |||
| if model_name in llm_constant.models_by_mode['chat']: | |||
| return "chat" | |||
| elif model_name in llm_constant.models_by_mode['completion']: | |||
| return "completion" | |||
| else: | |||
| raise ValueError(f"model name {model_name} is not supported.") | |||
| @classmethod | |||
| def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict: | |||
| """ | |||
| Returns the API credentials for the given tenant_id and model_name, based on the model's provider. | |||
| Raises an exception if the model_name is not found or if the provider is not found. | |||
| """ | |||
| if not model_name: | |||
| raise Exception('model name not found') | |||
| # | |||
| # if model_name not in llm_constant.models: | |||
| # raise Exception('model {} not found'.format(model_name)) | |||
| # model_provider = llm_constant.models[model_name] | |||
| provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) | |||
| return provider_service.get_credentials(model_name) | |||
| @classmethod | |||
| def get_default_provider(cls, tenant_id: str, model_name: str) -> str: | |||
| provider_name = llm_constant.models[model_name] | |||
| if provider_name == 'openai': | |||
| # get the default provider (openai / azure_openai) for the tenant | |||
| openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value) | |||
| azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value) | |||
| provider = None | |||
| if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value: | |||
| provider = openai_provider | |||
| elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value: | |||
| provider = azure_openai_provider | |||
| elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value: | |||
| provider = openai_provider | |||
| elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value: | |||
| provider = azure_openai_provider | |||
| if not provider: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {provider_name} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| provider_name = provider.provider_name | |||
| return provider_name | |||
| @@ -1,15 +0,0 @@ | |||
| import openai | |||
| from models.provider import ProviderName | |||
| class Moderation: | |||
| def __init__(self, provider: str, api_key: str): | |||
| self.provider = provider | |||
| self.api_key = api_key | |||
| if self.provider == ProviderName.OPENAI.value: | |||
| self.client = openai.Moderation | |||
| def moderate(self, text): | |||
| return self.client.create(input=text, api_key=self.api_key) | |||
| @@ -1,138 +0,0 @@ | |||
| import json | |||
| import logging | |||
| from typing import Optional, Union | |||
| import anthropic | |||
| from langchain.chat_models import ChatAnthropic | |||
| from langchain.schema import HumanMessage | |||
| from core import hosted_llm_credentials | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.llm.provider.base import BaseProvider | |||
| from core.llm.provider.errors import ValidateFailedError | |||
| from models.provider import ProviderName, ProviderType | |||
| class AnthropicProvider(BaseProvider): | |||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||
| return [ | |||
| { | |||
| 'id': 'claude-instant-1', | |||
| 'name': 'claude-instant-1', | |||
| }, | |||
| { | |||
| 'id': 'claude-2', | |||
| 'name': 'claude-2', | |||
| }, | |||
| ] | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| return self.get_provider_api_key(model_id=model_id) | |||
| def get_provider_name(self): | |||
| return ProviderName.ANTHROPIC | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the provider configs. | |||
| """ | |||
| try: | |||
| config = self.get_provider_api_key(only_custom=only_custom) | |||
| except: | |||
| config = { | |||
| 'anthropic_api_key': '' | |||
| } | |||
| if obfuscated: | |||
| if not config.get('anthropic_api_key'): | |||
| config = { | |||
| 'anthropic_api_key': '' | |||
| } | |||
| config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key')) | |||
| return config | |||
| return config | |||
| def get_encrypted_token(self, config: Union[dict | str]): | |||
| """ | |||
| Returns the encrypted token. | |||
| """ | |||
| return json.dumps({ | |||
| 'anthropic_api_key': self.encrypt_token(config['anthropic_api_key']) | |||
| }) | |||
| def get_decrypted_token(self, token: str): | |||
| """ | |||
| Returns the decrypted token. | |||
| """ | |||
| config = json.loads(token) | |||
| config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key']) | |||
| return config | |||
| def get_token_type(self): | |||
| return dict | |||
| def config_validate(self, config: Union[dict | str]): | |||
| """ | |||
| Validates the given config. | |||
| """ | |||
| # check OpenAI / Azure OpenAI credential is valid | |||
| openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value) | |||
| azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value) | |||
| provider = None | |||
| if openai_provider: | |||
| provider = openai_provider | |||
| elif azure_openai_provider: | |||
| provider = azure_openai_provider | |||
| if not provider: | |||
| raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.") | |||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||
| quota_used = provider.quota_used if provider.quota_used is not None else 0 | |||
| quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 | |||
| if quota_used >= quota_limit: | |||
| raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, " | |||
| f"please configure OpenAI or Azure OpenAI provider first.") | |||
| try: | |||
| if not isinstance(config, dict): | |||
| raise ValueError('Config must be a object.') | |||
| if 'anthropic_api_key' not in config: | |||
| raise ValueError('anthropic_api_key must be provided.') | |||
| chat_llm = ChatAnthropic( | |||
| model='claude-instant-1', | |||
| anthropic_api_key=config['anthropic_api_key'], | |||
| max_tokens_to_sample=10, | |||
| temperature=0, | |||
| default_request_timeout=60 | |||
| ) | |||
| messages = [ | |||
| HumanMessage( | |||
| content="ping" | |||
| ) | |||
| ] | |||
| chat_llm(messages) | |||
| except anthropic.APIConnectionError as ex: | |||
| raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}") | |||
| except (anthropic.APIStatusError, anthropic.RateLimitError) as ex: | |||
| raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - " | |||
| f"{ex.body['error']['type']}: {ex.body['error']['message']}") | |||
| except Exception as ex: | |||
| logging.exception('Anthropic config validation failed') | |||
| raise ex | |||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||
| if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key} | |||
| @@ -1,145 +0,0 @@ | |||
| import json | |||
| import logging | |||
| from typing import Optional, Union | |||
| import openai | |||
| import requests | |||
| from core.llm.provider.base import BaseProvider | |||
| from core.llm.provider.errors import ValidateFailedError | |||
| from models.provider import ProviderName | |||
| AZURE_OPENAI_API_VERSION = '2023-07-01-preview' | |||
| class AzureProvider(BaseProvider): | |||
| def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: | |||
| return [] | |||
| def check_embedding_model(self, credentials: Optional[dict] = None): | |||
| credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials | |||
| try: | |||
| result = openai.Embedding.create(input=['test'], | |||
| engine='text-embedding-ada-002', | |||
| timeout=60, | |||
| api_key=str(credentials.get('openai_api_key')), | |||
| api_base=str(credentials.get('openai_api_base')), | |||
| api_type='azure', | |||
| api_version=str(credentials.get('openai_api_version')))["data"][0][ | |||
| "embedding"] | |||
| except openai.error.AuthenticationError as e: | |||
| raise AzureAuthenticationError(str(e)) | |||
| except openai.error.APIConnectionError as e: | |||
| raise AzureRequestFailedError( | |||
| 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`') | |||
| except openai.error.InvalidRequestError as e: | |||
| if e.http_status == 404: | |||
| raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' " | |||
| "deployment name is exists in Azure AI") | |||
| else: | |||
| raise AzureRequestFailedError( | |||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||
| except openai.error.OpenAIError as e: | |||
| raise AzureRequestFailedError( | |||
| 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) | |||
| if not isinstance(result, list): | |||
| raise AzureRequestFailedError('Failed to request Azure OpenAI.') | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| """ | |||
| Returns the API credentials for Azure OpenAI as a dictionary. | |||
| """ | |||
| config = self.get_provider_api_key(model_id=model_id) | |||
| config['openai_api_type'] = 'azure' | |||
| config['openai_api_version'] = AZURE_OPENAI_API_VERSION | |||
| if model_id == 'text-embedding-ada-002': | |||
| config['deployment'] = model_id.replace('.', '') if model_id else None | |||
| config['chunk_size'] = 16 | |||
| else: | |||
| config['deployment_name'] = model_id.replace('.', '') if model_id else None | |||
| return config | |||
| def get_provider_name(self): | |||
| return ProviderName.AZURE_OPENAI | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the provider configs. | |||
| """ | |||
| try: | |||
| config = self.get_provider_api_key(only_custom=only_custom) | |||
| except: | |||
| config = { | |||
| 'openai_api_type': 'azure', | |||
| 'openai_api_version': AZURE_OPENAI_API_VERSION, | |||
| 'openai_api_base': '', | |||
| 'openai_api_key': '' | |||
| } | |||
| if obfuscated: | |||
| if not config.get('openai_api_key'): | |||
| config = { | |||
| 'openai_api_type': 'azure', | |||
| 'openai_api_version': AZURE_OPENAI_API_VERSION, | |||
| 'openai_api_base': '', | |||
| 'openai_api_key': '' | |||
| } | |||
| config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key')) | |||
| return config | |||
| return config | |||
| def get_token_type(self): | |||
| return dict | |||
| def config_validate(self, config: Union[dict | str]): | |||
| """ | |||
| Validates the given config. | |||
| """ | |||
| try: | |||
| if not isinstance(config, dict): | |||
| raise ValueError('Config must be a object.') | |||
| if 'openai_api_version' not in config: | |||
| config['openai_api_version'] = AZURE_OPENAI_API_VERSION | |||
| self.check_embedding_model(credentials=config) | |||
| except ValidateFailedError as e: | |||
| raise e | |||
| except AzureAuthenticationError: | |||
| raise ValidateFailedError('Validation failed, please check your API Key.') | |||
| except AzureRequestFailedError as ex: | |||
| raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | |||
| except Exception as ex: | |||
| logging.exception('Azure OpenAI Credentials validation failed') | |||
| raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) | |||
| def get_encrypted_token(self, config: Union[dict | str]): | |||
| """ | |||
| Returns the encrypted token. | |||
| """ | |||
| return json.dumps({ | |||
| 'openai_api_type': 'azure', | |||
| 'openai_api_version': AZURE_OPENAI_API_VERSION, | |||
| 'openai_api_base': config['openai_api_base'], | |||
| 'openai_api_key': self.encrypt_token(config['openai_api_key']) | |||
| }) | |||
| def get_decrypted_token(self, token: str): | |||
| """ | |||
| Returns the decrypted token. | |||
| """ | |||
| config = json.loads(token) | |||
| config['openai_api_key'] = self.decrypt_token(config['openai_api_key']) | |||
| return config | |||
| class AzureAuthenticationError(Exception): | |||
| pass | |||
| class AzureRequestFailedError(Exception): | |||
| pass | |||
| @@ -1,132 +0,0 @@ | |||
| import base64 | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional, Union | |||
| from core.constant import llm_constant | |||
| from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from libs import rsa | |||
| from models.account import Tenant | |||
| from models.provider import Provider, ProviderType, ProviderName | |||
| class BaseProvider(ABC): | |||
| def __init__(self, tenant_id: str): | |||
| self.tenant_id = tenant_id | |||
| def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the decrypted API key for the given tenant_id and provider_name. | |||
| If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. | |||
| If the provider is not found or not valid, raises a ProviderTokenNotInitError. | |||
| """ | |||
| provider = self.get_provider(only_custom) | |||
| if not provider: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {llm_constant.models[model_id]} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||
| quota_used = provider.quota_used if provider.quota_used is not None else 0 | |||
| quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 | |||
| if model_id and model_id == 'gpt-4': | |||
| raise ModelCurrentlyNotSupportError() | |||
| if quota_used >= quota_limit: | |||
| raise QuotaExceededError() | |||
| return self.get_hosted_credentials() | |||
| else: | |||
| return self.get_decrypted_token(provider.encrypted_config) | |||
| def get_provider(self, only_custom: bool = False) -> Optional[Provider]: | |||
| """ | |||
| Returns the Provider instance for the given tenant_id and provider_name. | |||
| If both CUSTOM and SYSTEM providers exist, a valid CUSTOM provider takes precedence; set only_custom to restrict the lookup to CUSTOM providers. | |||
| """ | |||
| return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom) | |||
| @classmethod | |||
| def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[ | |||
| Provider]: | |||
| """ | |||
| Returns the Provider instance for the given tenant_id and provider_name. | |||
| If both CUSTOM and SYSTEM providers exist, the valid CUSTOM provider is returned first. | |||
| """ | |||
| query = db.session.query(Provider).filter( | |||
| Provider.tenant_id == tenant_id | |||
| ) | |||
| if provider_name: | |||
| query = query.filter(Provider.provider_name == provider_name) | |||
| if only_custom: | |||
| query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value) | |||
| providers = query.order_by(Provider.provider_type.asc()).all() | |||
| for provider in providers: | |||
| if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: | |||
| return provider | |||
| elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: | |||
| return provider | |||
| return None | |||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the provider configs. | |||
| """ | |||
| try: | |||
| config = self.get_provider_api_key(only_custom=only_custom) | |||
| except: | |||
| config = '' | |||
| if obfuscated: | |||
| return self.obfuscated_token(config) | |||
| return config | |||
| def obfuscated_token(self, token: str): | |||
| return token[:6] + '*' * (len(token) - 8) + token[-2:] | |||
| def get_token_type(self): | |||
| return str | |||
| def get_encrypted_token(self, config: Union[dict | str]): | |||
| return self.encrypt_token(config) | |||
| def get_decrypted_token(self, token: str): | |||
| return self.decrypt_token(token) | |||
| def encrypt_token(self, token): | |||
| tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() | |||
| encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) | |||
| return base64.b64encode(encrypted_token).decode() | |||
| def decrypt_token(self, token): | |||
| return rsa.decrypt(base64.b64decode(token), self.tenant_id) | |||
| @abstractmethod | |||
| def get_provider_name(self): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def config_validate(self, config: str): | |||
| raise NotImplementedError | |||
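A rough illustration of the token helpers above (the key value is a placeholder; OpenAIProvider is one concrete BaseProvider subclass, and a Tenant row with an encrypt_public_key is assumed to exist):

    provider = OpenAIProvider(tenant_id='tenant-123')
    encrypted = provider.get_encrypted_token('sk-example-key')   # RSA-encrypted with the tenant's public key, base64-encoded
    plain = provider.get_decrypted_token(encrypted)              # back to 'sk-example-key'
    masked = provider.obfuscated_token(plain)                    # 'sk-exa******ey'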
| @@ -1,2 +0,0 @@ | |||
| class ValidateFailedError(Exception): | |||
| description = "Provider Validate failed" | |||
| @@ -1,22 +0,0 @@ | |||
| from typing import Optional | |||
| from core.llm.provider.base import BaseProvider | |||
| from models.provider import ProviderName | |||
| class HuggingfaceProvider(BaseProvider): | |||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||
| credentials = self.get_credentials(model_id) | |||
| # todo | |||
| return [] | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| """ | |||
| Returns the API credentials for Huggingface as a dictionary, for the given tenant_id. | |||
| """ | |||
| return { | |||
| 'huggingface_api_key': self.get_provider_api_key(model_id=model_id) | |||
| } | |||
| def get_provider_name(self): | |||
| return ProviderName.HUGGINGFACEHUB | |||
| @@ -1,53 +0,0 @@ | |||
| from typing import Optional, Union | |||
| from core.llm.provider.anthropic_provider import AnthropicProvider | |||
| from core.llm.provider.azure_provider import AzureProvider | |||
| from core.llm.provider.base import BaseProvider | |||
| from core.llm.provider.huggingface_provider import HuggingfaceProvider | |||
| from core.llm.provider.openai_provider import OpenAIProvider | |||
| from models.provider import Provider | |||
| class LLMProviderService: | |||
| def __init__(self, tenant_id: str, provider_name: str): | |||
| self.provider = self.init_provider(tenant_id, provider_name) | |||
| def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider: | |||
| if provider_name == 'openai': | |||
| return OpenAIProvider(tenant_id) | |||
| elif provider_name == 'azure_openai': | |||
| return AzureProvider(tenant_id) | |||
| elif provider_name == 'anthropic': | |||
| return AnthropicProvider(tenant_id) | |||
| elif provider_name == 'huggingface': | |||
| return HuggingfaceProvider(tenant_id) | |||
| else: | |||
| raise Exception('provider {} not found'.format(provider_name)) | |||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||
| return self.provider.get_models(model_id) | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| return self.provider.get_credentials(model_id) | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom) | |||
| def get_provider_db_record(self) -> Optional[Provider]: | |||
| return self.provider.get_provider() | |||
| def config_validate(self, config: Union[dict | str]): | |||
| """ | |||
| Validates the given config. | |||
| :param config: | |||
| :raises: ValidateFailedError | |||
| """ | |||
| return self.provider.config_validate(config) | |||
| def get_token_type(self): | |||
| return self.provider.get_token_type() | |||
| def get_encrypted_token(self, config: Union[dict | str]): | |||
| return self.provider.get_encrypted_token(config) | |||
| @@ -1,55 +0,0 @@ | |||
| import logging | |||
| from typing import Optional, Union | |||
| import openai | |||
| from openai.error import AuthenticationError, OpenAIError | |||
| from core import hosted_llm_credentials | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.llm.moderation import Moderation | |||
| from core.llm.provider.base import BaseProvider | |||
| from core.llm.provider.errors import ValidateFailedError | |||
| from models.provider import ProviderName | |||
| class OpenAIProvider(BaseProvider): | |||
| def get_models(self, model_id: Optional[str] = None) -> list[dict]: | |||
| credentials = self.get_credentials(model_id) | |||
| response = openai.Model.list(**credentials) | |||
| return [{ | |||
| 'id': model['id'], | |||
| 'name': model['id'], | |||
| } for model in response['data']] | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| """ | |||
| Returns the credentials for the given tenant_id and provider_name. | |||
| """ | |||
| return { | |||
| 'openai_api_key': self.get_provider_api_key(model_id=model_id) | |||
| } | |||
| def get_provider_name(self): | |||
| return ProviderName.OPENAI | |||
| def config_validate(self, config: Union[dict | str]): | |||
| """ | |||
| Validates the given config. | |||
| """ | |||
| try: | |||
| Moderation(self.get_provider_name().value, config).moderate('test') | |||
| except (AuthenticationError, OpenAIError) as ex: | |||
| raise ValidateFailedError(str(ex)) | |||
| except Exception as ex: | |||
| logging.exception('OpenAI config validation failed') | |||
| raise ex | |||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||
| if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| return hosted_llm_credentials.openai.api_key | |||
| @@ -1,62 +0,0 @@ | |||
| from typing import List, Optional, Any, Dict | |||
| from httpx import Timeout | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.chat_models import ChatAnthropic | |||
| from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage | |||
| from pydantic import root_validator | |||
| from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions | |||
| class StreamableChatAnthropic(ChatAnthropic): | |||
| """ | |||
| Wrapper around Anthropic's large language model. | |||
| """ | |||
| default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0) | |||
| @root_validator() | |||
| def prepare_params(cls, values: Dict) -> Dict: | |||
| values['model_name'] = values.get('model') | |||
| values['max_tokens'] = values.get('max_tokens_to_sample') | |||
| return values | |||
| @handle_anthropic_exceptions | |||
| def generate( | |||
| self, | |||
| messages: List[List[BaseMessage]], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| *, | |||
| tags: Optional[List[str]] = None, | |||
| metadata: Optional[Dict[str, Any]] = None, | |||
| **kwargs: Any, | |||
| ) -> LLMResult: | |||
| return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs) | |||
| @classmethod | |||
| def get_kwargs_from_model_params(cls, params: dict): | |||
| params['model'] = params.get('model_name') | |||
| del params['model_name'] | |||
| params['max_tokens_to_sample'] = params.get('max_tokens') | |||
| del params['max_tokens'] | |||
| del params['frequency_penalty'] | |||
| del params['presence_penalty'] | |||
| return params | |||
| def _convert_one_message_to_text(self, message: BaseMessage) -> str: | |||
| if isinstance(message, ChatMessage): | |||
| message_text = f"\n\n{message.role.capitalize()}: {message.content}" | |||
| elif isinstance(message, HumanMessage): | |||
| message_text = f"{self.HUMAN_PROMPT} {message.content}" | |||
| elif isinstance(message, AIMessage): | |||
| message_text = f"{self.AI_PROMPT} {message.content}" | |||
| elif isinstance(message, SystemMessage): | |||
| message_text = f"<admin>{message.content}</admin>" | |||
| else: | |||
| raise ValueError(f"Got unknown type {message}") | |||
| return message_text | |||
| @@ -1,41 +0,0 @@ | |||
| import decimal | |||
| from typing import Optional | |||
| import tiktoken | |||
| from core.constant import llm_constant | |||
| class TokenCalculator: | |||
| @classmethod | |||
| def get_num_tokens(cls, model_name: str, text: str): | |||
| if len(text) == 0: | |||
| return 0 | |||
| enc = tiktoken.encoding_for_model(model_name) | |||
| tokenized_text = enc.encode(text) | |||
| # calculate the number of tokens in the encoded text | |||
| return len(tokenized_text) | |||
| @classmethod | |||
| def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal: | |||
| if model_name in llm_constant.models_by_mode['embedding']: | |||
| unit_price = llm_constant.model_prices[model_name]['usage'] | |||
| elif text_type == 'prompt': | |||
| unit_price = llm_constant.model_prices[model_name]['prompt'] | |||
| elif text_type == 'completion': | |||
| unit_price = llm_constant.model_prices[model_name]['completion'] | |||
| else: | |||
| raise Exception('Invalid text type') | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * unit_price | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| @classmethod | |||
| def get_currency(cls, model_name: str): | |||
| return llm_constant.model_currency | |||
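As a worked example of the price calculation above, assuming a hypothetical unit price of 0.002 per 1K prompt tokens (the real value comes from llm_constant.model_prices):

    # 1,530 prompt tokens -> (1530 / 1000) quantized to 1.530, then 1.530 * 0.002 = 0.00306
    price = TokenCalculator.get_token_price('gpt-3.5-turbo', 1530, text_type='prompt')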
| @@ -1,26 +0,0 @@ | |||
| import openai | |||
| from core.llm.wrappers.openai_wrapper import handle_openai_exceptions | |||
| from models.provider import ProviderName | |||
| from core.llm.provider.base import BaseProvider | |||
| class Whisper: | |||
| def __init__(self, provider: BaseProvider): | |||
| self.provider = provider | |||
| if self.provider.get_provider_name() == ProviderName.OPENAI: | |||
| self.client = openai.Audio | |||
| self.credentials = provider.get_credentials() | |||
| @handle_openai_exceptions | |||
| def transcribe(self, file): | |||
| return self.client.transcribe( | |||
| model='whisper-1', | |||
| file=file, | |||
| api_key=self.credentials.get('openai_api_key'), | |||
| api_base=self.credentials.get('openai_api_base'), | |||
| api_type=self.credentials.get('openai_api_type'), | |||
| api_version=self.credentials.get('openai_api_version'), | |||
| ) | |||
| @@ -1,27 +0,0 @@ | |||
| import logging | |||
| from functools import wraps | |||
| import anthropic | |||
| from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ | |||
| LLMBadRequestError | |||
| def handle_anthropic_exceptions(func): | |||
| @wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| try: | |||
| return func(*args, **kwargs) | |||
| except anthropic.APIConnectionError as e: | |||
| logging.exception("Failed to connect to Anthropic API.") | |||
| raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}") | |||
| except anthropic.RateLimitError: | |||
| raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") | |||
| except anthropic.AuthenticationError as e: | |||
| raise LLMAuthorizationError(f"Anthropic: {e.message}") | |||
| except anthropic.BadRequestError as e: | |||
| raise LLMBadRequestError(f"Anthropic: {e.message}") | |||
| except anthropic.APIStatusError as e: | |||
| raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}") | |||
| return wrapper | |||
| @@ -1,31 +0,0 @@ | |||
| import logging | |||
| from functools import wraps | |||
| import openai | |||
| from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ | |||
| LLMBadRequestError | |||
| def handle_openai_exceptions(func): | |||
| @wraps(func) | |||
| def wrapper(*args, **kwargs): | |||
| try: | |||
| return func(*args, **kwargs) | |||
| except openai.error.InvalidRequestError as e: | |||
| logging.exception("Invalid request to OpenAI API.") | |||
| raise LLMBadRequestError(str(e)) | |||
| except openai.error.APIConnectionError as e: | |||
| logging.exception("Failed to connect to OpenAI API.") | |||
| raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e)) | |||
| except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e: | |||
| logging.exception("OpenAI service unavailable.") | |||
| raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e)) | |||
| except openai.error.RateLimitError as e: | |||
| raise LLMRateLimitError(str(e)) | |||
| except openai.error.AuthenticationError as e: | |||
| raise LLMAuthorizationError(str(e)) | |||
| except openai.error.OpenAIError as e: | |||
| raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) | |||
| return wrapper | |||
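A brief sketch of how this decorator would typically wrap a call site (the function name is illustrative; openai here is the pre-1.0 SDK that exposes openai.error):

    @handle_openai_exceptions
    def list_models(api_key: str):
        # any openai.error.* raised inside is translated into the LLM* error hierarchy above
        return openai.Model.list(api_key=api_key)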
| @@ -1,10 +1,10 @@ | |||
| from typing import Any, List, Dict, Union | |||
| from typing import Any, List, Dict | |||
| from langchain.memory.chat_memory import BaseChatMemory | |||
| from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel | |||
| from langchain.schema import get_buffer_string, BaseMessage | |||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | |||
| from core.llm.streamable_open_ai import StreamableOpenAI | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message | |||
| @@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): | |||
| conversation: Conversation | |||
| human_prefix: str = "Human" | |||
| ai_prefix: str = "Assistant" | |||
| llm: BaseLanguageModel | |||
| model_instance: BaseLLM | |||
| memory_key: str = "chat_history" | |||
| max_token_limit: int = 2000 | |||
| message_limit: int = 10 | |||
| @@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): | |||
| messages = list(reversed(messages)) | |||
| chat_messages: List[BaseMessage] = [] | |||
| chat_messages: List[PromptMessage] = [] | |||
| for message in messages: | |||
| chat_messages.append(HumanMessage(content=message.query)) | |||
| chat_messages.append(AIMessage(content=message.answer)) | |||
| chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN)) | |||
| chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) | |||
| if not chat_messages: | |||
| return chat_messages | |||
| return [] | |||
| # prune the chat message if it exceeds the max token limit | |||
| curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) | |||
| curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) | |||
| if curr_buffer_length > self.max_token_limit: | |||
| pruned_memory = [] | |||
| while curr_buffer_length > self.max_token_limit and chat_messages: | |||
| pruned_memory.append(chat_messages.pop(0)) | |||
| curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) | |||
| curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) | |||
| return chat_messages | |||
| return to_lc_messages(chat_messages) | |||
| @property | |||
| def memory_variables(self) -> List[str]: | |||
| @@ -0,0 +1,293 @@ | |||
| from typing import Optional | |||
| from langchain.callbacks.base import Callbacks | |||
| from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.speech2text.base import BaseSpeech2Text | |||
| from extensions.ext_database import db | |||
| from models.provider import TenantDefaultModel | |||
| class ModelFactory: | |||
| @classmethod | |||
| def get_text_generation_model_from_model_config(cls, tenant_id: str, | |||
| model_config: dict, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None) -> Optional[BaseLLM]: | |||
| provider_name = model_config.get("provider") | |||
| model_name = model_config.get("name") | |||
| completion_params = model_config.get("completion_params", {}) | |||
| return cls.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_provider_name=provider_name, | |||
| model_name=model_name, | |||
| model_kwargs=ModelKwargs( | |||
| temperature=completion_params.get('temperature', 0), | |||
| max_tokens=completion_params.get('max_tokens', 256), | |||
| top_p=completion_params.get('top_p', 0), | |||
| frequency_penalty=completion_params.get('frequency_penalty', 0.1), | |||
| presence_penalty=completion_params.get('presence_penalty', 0.1) | |||
| ), | |||
| streaming=streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| @classmethod | |||
| def get_text_generation_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None, | |||
| model_kwargs: Optional[ModelKwargs] = None, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None) -> Optional[BaseLLM]: | |||
| """ | |||
| get text generation model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :param model_kwargs: | |||
| :param streaming: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| is_default_model = False | |||
| if model_provider_name is None and model_name is None: | |||
| default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default System Reasoning Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| is_default_model = True | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init text generation model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION) | |||
| try: | |||
| model_instance = model_class( | |||
| model_provider=model_provider, | |||
| name=model_name, | |||
| model_kwargs=model_kwargs, | |||
| streaming=streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| except LLMBadRequestError as e: | |||
| if is_default_model: | |||
| raise LLMBadRequestError(f"Default model {model_name} is not available. " | |||
| f"Please check your model provider credentials.") | |||
| else: | |||
| raise e | |||
| if is_default_model: | |||
| model_instance.deduct_quota = False | |||
| return model_instance | |||
| @classmethod | |||
| def get_embedding_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None) -> Optional[BaseEmbedding]: | |||
| """ | |||
| get embedding model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| if model_provider_name is None and model_name is None: | |||
| default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default Embedding Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init embedding model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_speech2text_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: Optional[str] = None, | |||
| model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]: | |||
| """ | |||
| get speech to text model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| if model_provider_name is None and model_name is None: | |||
| default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT) | |||
| if not default_model: | |||
| raise LLMBadRequestError(f"Default model is not available. " | |||
| f"Please configure a Default Speech-to-Text Model " | |||
| f"in the Settings -> Model Provider.") | |||
| model_provider_name = default_model.provider_name | |||
| model_name = default_model.model_name | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init speech to text model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_moderation_model(cls, | |||
| tenant_id: str, | |||
| model_provider_name: str, | |||
| model_name: str) -> Optional[BaseProviderModel]: | |||
| """ | |||
| get moderation model. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| # get model provider | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| # init moderation model | |||
| model_class = model_provider.get_model_class(model_type=ModelType.MODERATION) | |||
| return model_class( | |||
| model_provider=model_provider, | |||
| name=model_name | |||
| ) | |||
| @classmethod | |||
| def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel: | |||
| """ | |||
| get default model of model type. | |||
| :param tenant_id: | |||
| :param model_type: | |||
| :return: | |||
| """ | |||
| # get default model | |||
| default_model = db.session.query(TenantDefaultModel) \ | |||
| .filter( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.value | |||
| ).first() | |||
| if not default_model: | |||
| model_provider_rules = ModelProviderFactory.get_provider_rules() | |||
| for model_provider_name, model_provider_rule in model_provider_rules.items(): | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) | |||
| if not model_provider: | |||
| continue | |||
| model_list = model_provider.get_supported_model_list(model_type) | |||
| if model_list: | |||
| model_info = model_list[0] | |||
| default_model = TenantDefaultModel( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type.value, | |||
| provider_name=model_provider_name, | |||
| model_name=model_info['id'] | |||
| ) | |||
| db.session.add(default_model) | |||
| db.session.commit() | |||
| break | |||
| return default_model | |||
| @classmethod | |||
| def update_default_model(cls, | |||
| tenant_id: str, | |||
| model_type: ModelType, | |||
| provider_name: str, | |||
| model_name: str) -> TenantDefaultModel: | |||
| """ | |||
| update default model of model type. | |||
| :param tenant_id: | |||
| :param model_type: | |||
| :param provider_name: | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| model_provider_name = ModelProviderFactory.get_provider_names() | |||
| if provider_name not in model_provider_name: | |||
| raise ValueError(f'Invalid provider name: {provider_name}') | |||
| model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name) | |||
| if not model_provider: | |||
| raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") | |||
| model_list = model_provider.get_supported_model_list(model_type) | |||
| model_ids = [model['id'] for model in model_list] | |||
| if model_name not in model_ids: | |||
| raise ValueError(f'Invalid model name: {model_name}') | |||
| # get default model | |||
| default_model = db.session.query(TenantDefaultModel) \ | |||
| .filter( | |||
| TenantDefaultModel.tenant_id == tenant_id, | |||
| TenantDefaultModel.model_type == model_type.value | |||
| ).first() | |||
| if default_model: | |||
| # update default model | |||
| default_model.provider_name = provider_name | |||
| default_model.model_name = model_name | |||
| db.session.commit() | |||
| else: | |||
| # create default model | |||
| default_model = TenantDefaultModel( | |||
| tenant_id=tenant_id, | |||
| model_type=model_type.value, | |||
| provider_name=provider_name, | |||
| model_name=model_name, | |||
| ) | |||
| db.session.add(default_model) | |||
| db.session.commit() | |||
| return default_model | |||
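A usage sketch of the factory above (tenant_id and the completion parameters are placeholders):

    model_instance = ModelFactory.get_text_generation_model_from_model_config(
        tenant_id='tenant-123',
        model_config={
            'provider': 'openai',
            'name': 'gpt-3.5-turbo',
            'completion_params': {'temperature': 0.7, 'max_tokens': 512}
        },
        streaming=False
    )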
| @@ -0,0 +1,228 @@ | |||
| from typing import Type | |||
| from sqlalchemy.exc import IntegrityError | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.model_providers.rules import provider_rules | |||
| from extensions.ext_database import db | |||
| from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType | |||
| DEFAULT_MODELS = { | |||
| ModelType.TEXT_GENERATION.value: { | |||
| 'provider_name': 'openai', | |||
| 'model_name': 'gpt-3.5-turbo', | |||
| }, | |||
| ModelType.EMBEDDINGS.value: { | |||
| 'provider_name': 'openai', | |||
| 'model_name': 'text-embedding-ada-002', | |||
| }, | |||
| ModelType.SPEECH_TO_TEXT.value: { | |||
| 'provider_name': 'openai', | |||
| 'model_name': 'whisper-1', | |||
| } | |||
| } | |||
| class ModelProviderFactory: | |||
| @classmethod | |||
| def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]: | |||
| if provider_name == 'openai': | |||
| from core.model_providers.providers.openai_provider import OpenAIProvider | |||
| return OpenAIProvider | |||
| elif provider_name == 'anthropic': | |||
| from core.model_providers.providers.anthropic_provider import AnthropicProvider | |||
| return AnthropicProvider | |||
| elif provider_name == 'minimax': | |||
| from core.model_providers.providers.minimax_provider import MinimaxProvider | |||
| return MinimaxProvider | |||
| elif provider_name == 'spark': | |||
| from core.model_providers.providers.spark_provider import SparkProvider | |||
| return SparkProvider | |||
| elif provider_name == 'tongyi': | |||
| from core.model_providers.providers.tongyi_provider import TongyiProvider | |||
| return TongyiProvider | |||
| elif provider_name == 'wenxin': | |||
| from core.model_providers.providers.wenxin_provider import WenxinProvider | |||
| return WenxinProvider | |||
| elif provider_name == 'chatglm': | |||
| from core.model_providers.providers.chatglm_provider import ChatGLMProvider | |||
| return ChatGLMProvider | |||
| elif provider_name == 'azure_openai': | |||
| from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider | |||
| return AzureOpenAIProvider | |||
| elif provider_name == 'replicate': | |||
| from core.model_providers.providers.replicate_provider import ReplicateProvider | |||
| return ReplicateProvider | |||
| elif provider_name == 'huggingface_hub': | |||
| from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider | |||
| return HuggingfaceHubProvider | |||
| else: | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def get_provider_names(cls): | |||
| """ | |||
| Returns a list of provider names. | |||
| """ | |||
| return list(provider_rules.keys()) | |||
| @classmethod | |||
| def get_provider_rules(cls): | |||
| """ | |||
| Returns the provider rules dict, keyed by provider name. | |||
| :return: | |||
| """ | |||
| return provider_rules | |||
| @classmethod | |||
| def get_provider_rule(cls, provider_name: str): | |||
| """ | |||
| Returns provider rule. | |||
| """ | |||
| return provider_rules[provider_name] | |||
| @classmethod | |||
| def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str): | |||
| """ | |||
| get preferred model provider. | |||
| :param tenant_id: a string representing the ID of the tenant. | |||
| :param model_provider_name: | |||
| :return: | |||
| """ | |||
| # get preferred provider | |||
| preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name) | |||
| if not preferred_provider or not preferred_provider.is_valid: | |||
| return None | |||
| # init model provider | |||
| model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name) | |||
| return model_provider_class(provider=preferred_provider) | |||
| @classmethod | |||
| def get_preferred_type_by_preferred_model_provider(cls, | |||
| tenant_id: str, | |||
| model_provider_name: str, | |||
| preferred_model_provider: TenantPreferredModelProvider): | |||
| """ | |||
| get preferred provider type by preferred model provider. | |||
| :param model_provider_name: | |||
| :param preferred_model_provider: | |||
| :return: | |||
| """ | |||
| if not preferred_model_provider: | |||
| model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) | |||
| support_provider_types = model_provider_rules['support_provider_types'] | |||
| if ProviderType.CUSTOM.value in support_provider_types: | |||
| custom_provider = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == model_provider_name, | |||
| Provider.provider_type == ProviderType.CUSTOM.value, | |||
| Provider.is_valid == True | |||
| ).first() | |||
| if custom_provider: | |||
| return ProviderType.CUSTOM.value | |||
| model_provider = cls.get_model_provider_class(model_provider_name) | |||
| if ProviderType.SYSTEM.value in support_provider_types \ | |||
| and model_provider.is_provider_type_system_supported(): | |||
| return ProviderType.SYSTEM.value | |||
| elif ProviderType.CUSTOM.value in support_provider_types: | |||
| return ProviderType.CUSTOM.value | |||
| else: | |||
| return preferred_model_provider.preferred_provider_type | |||
| @classmethod | |||
| def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str): | |||
| """ | |||
| get preferred provider of tenant. | |||
| :param tenant_id: | |||
| :param model_provider_name: | |||
| :return: | |||
| """ | |||
| # get preferred provider type | |||
| preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name) | |||
| # get providers by preferred provider type | |||
| providers = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == model_provider_name, | |||
| Provider.provider_type == preferred_provider_type | |||
| ).all() | |||
| no_system_provider = False | |||
| if preferred_provider_type == ProviderType.SYSTEM.value: | |||
| quota_type_to_provider_dict = {} | |||
| for provider in providers: | |||
| quota_type_to_provider_dict[provider.quota_type] = provider | |||
| model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) | |||
| for quota_type_enum in ProviderQuotaType: | |||
| quota_type = quota_type_enum.value | |||
| if quota_type in model_provider_rules['system_config']['supported_quota_types'] \ | |||
| and quota_type in quota_type_to_provider_dict.keys(): | |||
| provider = quota_type_to_provider_dict[quota_type] | |||
| if provider.is_valid and provider.quota_limit > provider.quota_used: | |||
| return provider | |||
| no_system_provider = True | |||
| if no_system_provider: | |||
| providers = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == model_provider_name, | |||
| Provider.provider_type == ProviderType.CUSTOM.value | |||
| ).all() | |||
| if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider: | |||
| if providers: | |||
| return providers[0] | |||
| else: | |||
| try: | |||
| provider = Provider( | |||
| tenant_id=tenant_id, | |||
| provider_name=model_provider_name, | |||
| provider_type=ProviderType.CUSTOM.value, | |||
| is_valid=False | |||
| ) | |||
| db.session.add(provider) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| provider = db.session.query(Provider) \ | |||
| .filter( | |||
| Provider.tenant_id == tenant_id, | |||
| Provider.provider_name == model_provider_name, | |||
| Provider.provider_type == ProviderType.CUSTOM.value | |||
| ).first() | |||
| return provider | |||
| return None | |||
| @classmethod | |||
| def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str): | |||
| """ | |||
| get preferred provider type of tenant. | |||
| :param tenant_id: | |||
| :param model_provider_name: | |||
| :return: | |||
| """ | |||
| preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ | |||
| .filter( | |||
| TenantPreferredModelProvider.tenant_id == tenant_id, | |||
| TenantPreferredModelProvider.provider_name == model_provider_name | |||
| ).first() | |||
| return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider) | |||
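A short sketch of resolving a tenant's preferred provider with the factory above (tenant_id is a placeholder):

    provider = ModelProviderFactory.get_preferred_model_provider('tenant-123', 'openai')
    if provider:
        llm_class = provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
    else:
        # no valid system quota or custom credentials exist for this tenant yet
        raise ProviderTokenNotInitError('openai provider credentials are not initialized.')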
| @@ -0,0 +1,22 @@ | |||
| from abc import ABC | |||
| from typing import Any | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| class BaseProviderModel(ABC): | |||
| _client: Any | |||
| _model_provider: BaseModelProvider | |||
| def __init__(self, model_provider: BaseModelProvider, client: Any): | |||
| self._model_provider = model_provider | |||
| self._client = client | |||
| @property | |||
| def client(self): | |||
| return self._client | |||
| @property | |||
| def model_provider(self): | |||
| return self._model_provider | |||
| @@ -0,0 +1,78 @@ | |||
| import decimal | |||
| import logging | |||
| import openai | |||
| import tiktoken | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \ | |||
| LLMAPIUnavailableError, LLMAPIConnectionError | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| AZURE_OPENAI_API_VERSION = '2023-07-01-preview' | |||
| class AzureOpenAIEmbedding(BaseEmbedding): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| self.credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = OpenAIEmbeddings( | |||
| deployment=name, | |||
| openai_api_type='azure', | |||
| openai_api_version=AZURE_OPENAI_API_VERSION, | |||
| chunk_size=16, | |||
| max_retries=1, | |||
| **self.credentials | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| def get_num_tokens(self, text: str) -> int: | |||
| """ | |||
| get num tokens of text. | |||
| :param text: | |||
| :return: | |||
| """ | |||
| if len(text) == 0: | |||
| return 0 | |||
| enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name')) | |||
| tokenized_text = enc.encode(text) | |||
| # calculate the number of tokens in the encoded text | |||
| return len(tokenized_text) | |||
| def get_token_price(self, tokens: int): | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * decimal.Decimal('0.0001') | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to Azure OpenAI API.") | |||
| return LLMBadRequestError(str(ex)) | |||
| elif isinstance(ex, openai.error.APIConnectionError): | |||
| logging.warning("Failed to connect to Azure OpenAI API.") | |||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||
| logging.warning("Azure OpenAI service unavailable.") | |||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, openai.error.RateLimitError): | |||
| return LLMRateLimitError('Azure ' + str(ex)) | |||
| elif isinstance(ex, openai.error.AuthenticationError): | |||
| raise LLMAuthorizationError('Azure ' + str(ex)) | |||
| elif isinstance(ex, openai.error.OpenAIError): | |||
| return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex)) | |||
| else: | |||
| return ex | |||
| @@ -0,0 +1,40 @@ | |||
| from abc import abstractmethod | |||
| from typing import Any | |||
| import tiktoken | |||
| from langchain.schema.language_model import _get_token_ids_default_method | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| class BaseEmbedding(BaseProviderModel): | |||
| name: str | |||
| type: ModelType = ModelType.EMBEDDINGS | |||
| def __init__(self, model_provider: BaseModelProvider, client: Any, name: str): | |||
| super().__init__(model_provider, client) | |||
| self.name = name | |||
| def get_num_tokens(self, text: str) -> int: | |||
| """ | |||
| get num tokens of text. | |||
| :param text: | |||
| :return: | |||
| """ | |||
| if len(text) == 0: | |||
| return 0 | |||
| return len(_get_token_ids_default_method(text)) | |||
| def get_token_price(self, tokens: int): | |||
| return 0 | |||
| def get_currency(self): | |||
| return 'USD' | |||
| @abstractmethod | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,35 @@ | |||
| import decimal | |||
| import logging | |||
| from langchain.embeddings import MiniMaxEmbeddings | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| class MinimaxEmbedding(BaseEmbedding): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = MiniMaxEmbeddings( | |||
| model=name, | |||
| **credentials | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| def get_token_price(self, tokens: int): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, ValueError): | |||
| return LLMBadRequestError(f"Minimax: {str(ex)}") | |||
| else: | |||
| return ex | |||
| @@ -0,0 +1,72 @@ | |||
| import decimal | |||
| import logging | |||
| import openai | |||
| import tiktoken | |||
| from langchain.embeddings import OpenAIEmbeddings | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| class OpenAIEmbedding(BaseEmbedding): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = OpenAIEmbeddings( | |||
| max_retries=1, | |||
| **credentials | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| def get_num_tokens(self, text: str) -> int: | |||
| """ | |||
| get num tokens of text. | |||
| :param text: | |||
| :return: | |||
| """ | |||
| if len(text) == 0: | |||
| return 0 | |||
| enc = tiktoken.encoding_for_model(self.name) | |||
| tokenized_text = enc.encode(text) | |||
| # calculate the number of tokens in the encoded text | |||
| return len(tokenized_text) | |||
| def get_token_price(self, tokens: int): | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * decimal.Decimal('0.0001') | |||
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to OpenAI API.") | |||
| return LLMBadRequestError(str(ex)) | |||
| elif isinstance(ex, openai.error.APIConnectionError): | |||
| logging.warning("Failed to connect to OpenAI API.") | |||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||
| logging.warning("OpenAI service unavailable.") | |||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, openai.error.RateLimitError): | |||
| return LLMRateLimitError(str(ex)) | |||
| elif isinstance(ex, openai.error.AuthenticationError): | |||
| raise LLMAuthorizationError(str(ex)) | |||
| elif isinstance(ex, openai.error.OpenAIError): | |||
| return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) | |||
| else: | |||
| return ex | |||
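As a worked example of the embedding price above: 2,345 tokens -> (2345 / 1000) quantized to 2.345, times 0.0001 USD per 1K tokens, i.e. about 0.0002345 USD.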
| @@ -0,0 +1,36 @@ | |||
| import decimal | |||
| from replicate.exceptions import ModelError, ReplicateError | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings | |||
| from core.model_providers.models.embedding.base import BaseEmbedding | |||
| class ReplicateEmbedding(BaseEmbedding): | |||
| def __init__(self, model_provider: BaseModelProvider, name: str): | |||
| credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| client = ReplicateEmbeddings( | |||
| model=name + ':' + credentials.get('model_version'), | |||
| replicate_api_token=credentials.get('replicate_api_token') | |||
| ) | |||
| super().__init__(model_provider, client, name) | |||
| def get_token_price(self, tokens: int): | |||
| # replicate charges only for prediction seconds | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, (ModelError, ReplicateError)): | |||
| return LLMBadRequestError(f"Replicate: {str(ex)}") | |||
| else: | |||
| return ex | |||
| @@ -0,0 +1,53 @@ | |||
| import enum | |||
| from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage | |||
| from pydantic import BaseModel | |||
| class LLMRunResult(BaseModel): | |||
| content: str | |||
| prompt_tokens: int | |||
| completion_tokens: int | |||
| class MessageType(enum.Enum): | |||
| HUMAN = 'human' | |||
| ASSISTANT = 'assistant' | |||
| SYSTEM = 'system' | |||
| class PromptMessage(BaseModel): | |||
| type: MessageType = MessageType.HUMAN | |||
| content: str = '' | |||
| def to_lc_messages(messages: list[PromptMessage]): | |||
| lc_messages = [] | |||
| for message in messages: | |||
| if message.type == MessageType.HUMAN: | |||
| lc_messages.append(HumanMessage(content=message.content)) | |||
| elif message.type == MessageType.ASSISTANT: | |||
| lc_messages.append(AIMessage(content=message.content)) | |||
| elif message.type == MessageType.SYSTEM: | |||
| lc_messages.append(SystemMessage(content=message.content)) | |||
| return lc_messages | |||
| def to_prompt_messages(messages: list[BaseMessage]): | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| if isinstance(message, HumanMessage): | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) | |||
| elif isinstance(message, AIMessage): | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT)) | |||
| elif isinstance(message, SystemMessage): | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) | |||
| return prompt_messages | |||
| def str_to_prompt_messages(texts: list[str]): | |||
| prompt_messages = [] | |||
| for text in texts: | |||
| prompt_messages.append(PromptMessage(content=text)) | |||
| return prompt_messages | |||
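A quick round-trip example with the helpers above:

    prompts = [
        PromptMessage(content='You are a helpful assistant.', type=MessageType.SYSTEM),
        PromptMessage(content='Hello!'),                    # type defaults to MessageType.HUMAN
    ]
    lc_messages = to_lc_messages(prompts)    # [SystemMessage(...), HumanMessage(...)]
    back = to_prompt_messages(lc_messages)   # equivalent PromptMessage list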
| @@ -0,0 +1,59 @@ | |||
| import enum | |||
| from typing import Optional, TypeVar, Generic | |||
| from langchain.load.serializable import Serializable | |||
| from pydantic import BaseModel | |||
| class ModelMode(enum.Enum): | |||
| COMPLETION = 'completion' | |||
| CHAT = 'chat' | |||
| class ModelType(enum.Enum): | |||
| TEXT_GENERATION = 'text-generation' | |||
| EMBEDDINGS = 'embeddings' | |||
| SPEECH_TO_TEXT = 'speech2text' | |||
| IMAGE = 'image' | |||
| VIDEO = 'video' | |||
| MODERATION = 'moderation' | |||
| @staticmethod | |||
| def value_of(value): | |||
| for member in ModelType: | |||
| if member.value == value: | |||
| return member | |||
| raise ValueError(f"No matching enum found for value '{value}'") | |||
| class ModelKwargs(BaseModel): | |||
| max_tokens: Optional[int] | |||
| temperature: Optional[float] | |||
| top_p: Optional[float] | |||
| presence_penalty: Optional[float] | |||
| frequency_penalty: Optional[float] | |||
| class KwargRuleType(enum.Enum): | |||
| STRING = 'string' | |||
| INTEGER = 'integer' | |||
| FLOAT = 'float' | |||
| T = TypeVar('T') | |||
| class KwargRule(Generic[T], BaseModel): | |||
| enabled: bool = True | |||
| min: Optional[T] = None | |||
| max: Optional[T] = None | |||
| default: Optional[T] = None | |||
| alias: Optional[str] = None | |||
| class ModelKwargsRules(BaseModel): | |||
| max_tokens: KwargRule = KwargRule[int](enabled=False) | |||
| temperature: KwargRule = KwargRule[float](enabled=False) | |||
| top_p: KwargRule = KwargRule[float](enabled=False) | |||
| presence_penalty: KwargRule = KwargRule[float](enabled=False) | |||
| frequency_penalty: KwargRule = KwargRule[float](enabled=False) | |||
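For example, a rule set that only exposes temperature and max_tokens could be declared like this (the ranges and alias are illustrative, not taken from any real provider rule file):

    rules = ModelKwargsRules(
        temperature=KwargRule[float](min=0.0, max=2.0, default=1.0),
        max_tokens=KwargRule[int](min=1, max=4096, default=256, alias='max_tokens_to_sample'),
    )
    assert ModelType.value_of('embeddings') == ModelType.EMBEDDINGS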
| @@ -0,0 +1,10 @@ | |||
| from enum import Enum | |||
| class ProviderQuotaUnit(Enum): | |||
| TIMES = 'times' | |||
| TOKENS = 'tokens' | |||
| class ModelFeature(Enum): | |||
| AGENT_THOUGHT = 'agent_thought' | |||
| @@ -0,0 +1,107 @@ | |||
| import decimal | |||
| import logging | |||
| from functools import wraps | |||
| from typing import List, Optional, Any | |||
| import anthropic | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.chat_models import ChatAnthropic | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| class AnthropicModel(BaseLLM): | |||
| model_mode: ModelMode = ModelMode.CHAT | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| return ChatAnthropic( | |||
| model=self.name, | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| default_request_timeout=60, | |||
| **self.credentials, | |||
| **provider_model_kwargs | |||
| ) | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run predict by prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'claude-instant-1': { | |||
| 'prompt': decimal.Decimal('1.63'), | |||
| 'completion': decimal.Decimal('5.51'), | |||
| }, | |||
| 'claude-2': { | |||
| 'prompt': decimal.Decimal('11.02'), | |||
| 'completion': decimal.Decimal('32.68'), | |||
| }, | |||
| } | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[self.name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[self.name]['completion'] | |||
| tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1m * unit_price | |||
| return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, anthropic.APIConnectionError): | |||
| logging.warning("Failed to connect to Anthropic API.") | |||
| return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}") | |||
| elif isinstance(ex, anthropic.RateLimitError): | |||
| return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") | |||
| elif isinstance(ex, anthropic.AuthenticationError): | |||
| return LLMAuthorizationError(f"Anthropic: {ex.message}") | |||
| elif isinstance(ex, anthropic.BadRequestError): | |||
| return LLMBadRequestError(f"Anthropic: {ex.message}") | |||
| elif isinstance(ex, anthropic.APIStatusError): | |||
| return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}") | |||
| else: | |||
| return ex | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return True | |||
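As a worked example of the pricing above: 12,000 completion tokens on claude-2 cost (12000 / 1,000,000) * 32.68 = 0.39216 USD, since the unit prices in the table are per million tokens.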
| @@ -0,0 +1,177 @@ | |||
| import decimal | |||
| import logging | |||
| from functools import wraps | |||
| from typing import List, Optional, Any | |||
| import openai | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI | |||
| from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| AZURE_OPENAI_API_VERSION = '2023-07-01-preview' | |||
| class AzureOpenAIModel(BaseLLM): | |||
| def __init__(self, model_provider: BaseModelProvider, | |||
| name: str, | |||
| model_kwargs: ModelKwargs, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None): | |||
| if name == 'text-davinci-003': | |||
| self.model_mode = ModelMode.COMPLETION | |||
| else: | |||
| self.model_mode = ModelMode.CHAT | |||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| if self.name == 'text-davinci-003': | |||
| client = EnhanceAzureOpenAI( | |||
| deployment_name=self.name, | |||
| streaming=self.streaming, | |||
| request_timeout=60, | |||
| openai_api_type='azure', | |||
| openai_api_version=AZURE_OPENAI_API_VERSION, | |||
| openai_api_key=self.credentials.get('openai_api_key'), | |||
| openai_api_base=self.credentials.get('openai_api_base'), | |||
| callbacks=self.callbacks, | |||
| **provider_model_kwargs | |||
| ) | |||
| else: | |||
| extra_model_kwargs = { | |||
| 'top_p': provider_model_kwargs.get('top_p'), | |||
| 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), | |||
| 'presence_penalty': provider_model_kwargs.get('presence_penalty'), | |||
| } | |||
| client = EnhanceAzureChatOpenAI( | |||
| deployment_name=self.name, | |||
| temperature=provider_model_kwargs.get('temperature'), | |||
| max_tokens=provider_model_kwargs.get('max_tokens'), | |||
| model_kwargs=extra_model_kwargs, | |||
| streaming=self.streaming, | |||
| request_timeout=60, | |||
| openai_api_type='azure', | |||
| openai_api_version=AZURE_OPENAI_API_VERSION, | |||
| openai_api_key=self.credentials.get('openai_api_key'), | |||
| openai_api_base=self.credentials.get('openai_api_base'), | |||
| callbacks=self.callbacks, | |||
| ) | |||
| return client | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| if isinstance(prompts, str): | |||
| return self._client.get_num_tokens(prompts) | |||
| else: | |||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'gpt-4': { | |||
| 'prompt': decimal.Decimal('0.03'), | |||
| 'completion': decimal.Decimal('0.06'), | |||
| }, | |||
| 'gpt-4-32k': { | |||
| 'prompt': decimal.Decimal('0.06'), | |||
| 'completion': decimal.Decimal('0.12') | |||
| }, | |||
| 'gpt-35-turbo': { | |||
| 'prompt': decimal.Decimal('0.0015'), | |||
| 'completion': decimal.Decimal('0.002') | |||
| }, | |||
| 'gpt-35-turbo-16k': { | |||
| 'prompt': decimal.Decimal('0.003'), | |||
| 'completion': decimal.Decimal('0.004') | |||
| }, | |||
| 'text-davinci-003': { | |||
| 'prompt': decimal.Decimal('0.02'), | |||
| 'completion': decimal.Decimal('0.02') | |||
| }, | |||
| } | |||
| base_model_name = self.credentials.get("base_model_name") | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[base_model_name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[base_model_name]['completion'] | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * unit_price | |||
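| # Worked example (illustrative): 1,500 prompt tokens on the gpt-35-turbo base model -> | |||
| # Decimal('1.500') * Decimal('0.0015') = 0.00225 USD before the final quantize below. | |||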
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| if self.name == 'text-davinci-003': | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| else: | |||
| extra_model_kwargs = { | |||
| 'top_p': provider_model_kwargs.get('top_p'), | |||
| 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), | |||
| 'presence_penalty': provider_model_kwargs.get('presence_penalty'), | |||
| } | |||
| self.client.temperature = provider_model_kwargs.get('temperature') | |||
| self.client.max_tokens = provider_model_kwargs.get('max_tokens') | |||
| self.client.model_kwargs = extra_model_kwargs | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to Azure OpenAI API.") | |||
| return LLMBadRequestError(str(ex)) | |||
| elif isinstance(ex, openai.error.APIConnectionError): | |||
| logging.warning("Failed to connect to Azure OpenAI API.") | |||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||
| logging.warning("Azure OpenAI service unavailable.") | |||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, openai.error.RateLimitError): | |||
| return LLMRateLimitError('Azure ' + str(ex)) | |||
| elif isinstance(ex, openai.error.AuthenticationError): | |||
| return LLMAuthorizationError('Azure ' + str(ex)) | |||
| elif isinstance(ex, openai.error.OpenAIError): | |||
| return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex)) | |||
| else: | |||
| return ex | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return True | |||
| @@ -0,0 +1,269 @@ | |||
| from abc import abstractmethod | |||
| from typing import List, Optional, Any, Union | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration | |||
| from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| class BaseLLM(BaseProviderModel): | |||
| model_mode: ModelMode = ModelMode.COMPLETION | |||
| name: str | |||
| model_kwargs: ModelKwargs | |||
| credentials: dict | |||
| streaming: bool = False | |||
| type: ModelType = ModelType.TEXT_GENERATION | |||
| deduct_quota: bool = True | |||
| def __init__(self, model_provider: BaseModelProvider, | |||
| name: str, | |||
| model_kwargs: ModelKwargs, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None): | |||
| self.name = name | |||
| self.model_rules = model_provider.get_model_parameter_rules(name, self.type) | |||
| self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs( | |||
| max_tokens=None, | |||
| temperature=None, | |||
| top_p=None, | |||
| presence_penalty=None, | |||
| frequency_penalty=None | |||
| ) | |||
| self.credentials = model_provider.get_model_credentials( | |||
| model_name=name, | |||
| model_type=self.type | |||
| ) | |||
| self.streaming = streaming | |||
| if streaming: | |||
| default_callback = DifyStreamingStdOutCallbackHandler() | |||
| else: | |||
| default_callback = DifyStdOutCallbackHandler() | |||
| if not callbacks: | |||
| callbacks = [default_callback] | |||
| else: | |||
| callbacks.append(default_callback) | |||
| self.callbacks = callbacks | |||
| client = self._init_client() | |||
| super().__init__(model_provider, client) | |||
| @abstractmethod | |||
| def _init_client(self) -> Any: | |||
| raise NotImplementedError | |||
| def run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMRunResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| if self.deduct_quota: | |||
| self.model_provider.check_quota_over_limit() | |||
| if not callbacks: | |||
| callbacks = self.callbacks | |||
| else: | |||
| callbacks.extend(self.callbacks) | |||
| if 'fake_response' in kwargs and kwargs['fake_response']: | |||
| prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) | |||
| fake_llm = FakeLLM( | |||
| response=kwargs['fake_response'], | |||
| num_token_func=self.get_num_tokens, | |||
| streaming=self.streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| result = fake_llm.generate([prompts]) | |||
| else: | |||
| try: | |||
| result = self._run( | |||
| messages=messages, | |||
| stop=stop, | |||
| callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None, | |||
| **kwargs | |||
| ) | |||
| except Exception as ex: | |||
| raise self.handle_exceptions(ex) | |||
| if isinstance(result.generations[0][0], ChatGeneration): | |||
| completion_content = result.generations[0][0].message.content | |||
| else: | |||
| completion_content = result.generations[0][0].text | |||
| if self.streaming and not self.support_streaming(): | |||
| # use FakeLLM to simulate streaming when the current model does not support streaming but streaming is enabled | |||
| prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) | |||
| fake_llm = FakeLLM( | |||
| response=completion_content, | |||
| num_token_func=self.get_num_tokens, | |||
| streaming=self.streaming, | |||
| callbacks=callbacks | |||
| ) | |||
| fake_llm.generate([prompts]) | |||
| if result.llm_output and result.llm_output.get('token_usage'): | |||
| prompt_tokens = result.llm_output['token_usage']['prompt_tokens'] | |||
| completion_tokens = result.llm_output['token_usage']['completion_tokens'] | |||
| total_tokens = result.llm_output['token_usage']['total_tokens'] | |||
| else: | |||
| prompt_tokens = self.get_num_tokens(messages) | |||
| completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]) | |||
| total_tokens = prompt_tokens + completion_tokens | |||
| if self.deduct_quota: | |||
| self.model_provider.deduct_quota(total_tokens) | |||
| return LLMRunResult( | |||
| content=completion_content, | |||
| prompt_tokens=prompt_tokens, | |||
| completion_tokens=completion_tokens | |||
| ) | |||
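| # Illustrative usage (names assumed from this module): a concrete subclass instance can be | |||
| # called as model.run([PromptMessage(content='Hello', type=MessageType.HUMAN)]) and returns | |||
| # an LLMRunResult with the completion text plus prompt and completion token counts. | |||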
| @abstractmethod | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| """ | |||
| get token price. | |||
| :param tokens: | |||
| :param message_type: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_currency(self): | |||
| """ | |||
| get token currency. | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| def get_model_kwargs(self): | |||
| return self.model_kwargs | |||
| def set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| self.model_kwargs = model_kwargs | |||
| self._set_model_kwargs(model_kwargs) | |||
| @abstractmethod | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| """ | |||
| Handle llm run exceptions. | |||
| :param ex: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| def add_callbacks(self, callbacks: Callbacks): | |||
| """ | |||
| Add callbacks to client. | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| if not self.client.callbacks: | |||
| self.client.callbacks = callbacks | |||
| else: | |||
| self.client.callbacks.extend(callbacks) | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return False | |||
| def _get_prompt_from_messages(self, messages: List[PromptMessage], | |||
| model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]: | |||
| if len(messages) == 0: | |||
| raise ValueError("prompt must not be empty.") | |||
| if not model_mode: | |||
| model_mode = self.model_mode | |||
| if model_mode == ModelMode.COMPLETION: | |||
| return messages[0].content | |||
| else: | |||
| chat_messages = [] | |||
| for message in messages: | |||
| if message.type == MessageType.HUMAN: | |||
| chat_messages.append(HumanMessage(content=message.content)) | |||
| elif message.type == MessageType.ASSISTANT: | |||
| chat_messages.append(AIMessage(content=message.content)) | |||
| elif message.type == MessageType.SYSTEM: | |||
| chat_messages.append(SystemMessage(content=message.content)) | |||
| return chat_messages | |||
| def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict: | |||
| """ | |||
| convert model kwargs to provider model kwargs. | |||
| :param model_rules: | |||
| :param model_kwargs: | |||
| :return: | |||
| """ | |||
| model_kwargs_input = {} | |||
| for key, value in model_kwargs.dict().items(): | |||
| rule = getattr(model_rules, key) | |||
| if not rule.enabled: | |||
| continue | |||
| if rule.alias: | |||
| key = rule.alias | |||
| if rule.default is not None and value is None: | |||
| value = rule.default | |||
| if rule.min is not None: | |||
| value = max(value, rule.min) | |||
| if rule.max is not None: | |||
| value = min(value, rule.max) | |||
| model_kwargs_input[key] = value | |||
| return model_kwargs_input | |||
| @@ -0,0 +1,70 @@ | |||
| import decimal | |||
| from typing import List, Optional, Any | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.llms import ChatGLM | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| class ChatGLMModel(BaseLLM): | |||
| model_mode: ModelMode = ModelMode.COMPLETION | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| return ChatGLM( | |||
| callbacks=self.callbacks, | |||
| endpoint_url=self.credentials.get('api_base'), | |||
| **provider_model_kwargs | |||
| ) | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, ValueError): | |||
| return LLMBadRequestError(f"ChatGLM: {str(ex)}") | |||
| else: | |||
| return ex | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return False | |||
| @@ -0,0 +1,82 @@ | |||
| import decimal | |||
| from functools import wraps | |||
| from typing import List, Optional, Any | |||
| from langchain import HuggingFaceHub | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.llms import HuggingFaceEndpoint | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| class HuggingfaceHubModel(BaseLLM): | |||
| model_mode: ModelMode = ModelMode.COMPLETION | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints': | |||
| client = HuggingFaceEndpoint( | |||
| endpoint_url=self.credentials['huggingfacehub_endpoint_url'], | |||
| task='text2text-generation', | |||
| model_kwargs=provider_model_kwargs, | |||
| huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'], | |||
| callbacks=self.callbacks, | |||
| ) | |||
| else: | |||
| client = HuggingFaceHub( | |||
| repo_id=self.name, | |||
| task=self.credentials['task_type'], | |||
| model_kwargs=provider_model_kwargs, | |||
| huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'], | |||
| callbacks=self.callbacks, | |||
| ) | |||
| return client | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.get_num_tokens(prompts) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| # price calculation is not supported for Hugging Face Hub models | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| self.client.model_kwargs = provider_model_kwargs | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| return LLMBadRequestError(f"Huggingface Hub: {str(ex)}") | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return False | |||
| @@ -0,0 +1,70 @@ | |||
| import decimal | |||
| from typing import List, Optional, Any | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.llms import Minimax | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| class MinimaxModel(BaseLLM): | |||
| model_mode: ModelMode = ModelMode.COMPLETION | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| return Minimax( | |||
| model=self.name, | |||
| model_kwargs={ | |||
| 'stream': False | |||
| }, | |||
| callbacks=self.callbacks, | |||
| **self.credentials, | |||
| **provider_model_kwargs | |||
| ) | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return max(self._client.get_num_tokens(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'RMB' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, ValueError): | |||
| return LLMBadRequestError(f"Minimax: {str(ex)}") | |||
| else: | |||
| return ex | |||
| @@ -0,0 +1,219 @@ | |||
| import decimal | |||
| import logging | |||
| from typing import List, Optional, Any | |||
| import openai | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI | |||
| from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ | |||
| LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError | |||
| from core.third_party.langchain.llms.open_ai import EnhanceOpenAI | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| from models.provider import ProviderType, ProviderQuotaType | |||
| COMPLETION_MODELS = [ | |||
| 'text-davinci-003', # 4,097 tokens | |||
| ] | |||
| CHAT_MODELS = [ | |||
| 'gpt-4', # 8,192 tokens | |||
| 'gpt-4-32k', # 32,768 tokens | |||
| 'gpt-3.5-turbo', # 4,096 tokens | |||
| 'gpt-3.5-turbo-16k', # 16,384 tokens | |||
| ] | |||
| MODEL_MAX_TOKENS = { | |||
| 'gpt-4': 8192, | |||
| 'gpt-4-32k': 32768, | |||
| 'gpt-3.5-turbo': 4096, | |||
| 'gpt-3.5-turbo-16k': 16384, | |||
| 'text-davinci-003': 4097, | |||
| } | |||
| class OpenAIModel(BaseLLM): | |||
| def __init__(self, model_provider: BaseModelProvider, | |||
| name: str, | |||
| model_kwargs: ModelKwargs, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None): | |||
| if name in COMPLETION_MODELS: | |||
| self.model_mode = ModelMode.COMPLETION | |||
| else: | |||
| self.model_mode = ModelMode.CHAT | |||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| if self.name in COMPLETION_MODELS: | |||
| client = EnhanceOpenAI( | |||
| model_name=self.name, | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| request_timeout=60, | |||
| **self.credentials, | |||
| **provider_model_kwargs | |||
| ) | |||
| else: | |||
| # Fine-tuning is currently only available for the base models | |||
| # davinci, curie, babbage, and ada. | |||
| # This means that, apart from the models in the fixed COMPLETION_MODELS list, | |||
| # any other fine-tuned model is also a `completion` model. | |||
| extra_model_kwargs = { | |||
| 'top_p': provider_model_kwargs.get('top_p'), | |||
| 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), | |||
| 'presence_penalty': provider_model_kwargs.get('presence_penalty'), | |||
| } | |||
| client = EnhanceChatOpenAI( | |||
| model_name=self.name, | |||
| temperature=provider_model_kwargs.get('temperature'), | |||
| max_tokens=provider_model_kwargs.get('max_tokens'), | |||
| model_kwargs=extra_model_kwargs, | |||
| streaming=self.streaming, | |||
| callbacks=self.callbacks, | |||
| request_timeout=60, | |||
| **self.credentials | |||
| ) | |||
| return client | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| if self.name == 'gpt-4' \ | |||
| and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \ | |||
| and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value: | |||
| raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.") | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| return self._client.generate([prompts], stop, callbacks) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| if isinstance(prompts, str): | |||
| return self._client.get_num_tokens(prompts) | |||
| else: | |||
| return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| model_unit_prices = { | |||
| 'gpt-4': { | |||
| 'prompt': decimal.Decimal('0.03'), | |||
| 'completion': decimal.Decimal('0.06'), | |||
| }, | |||
| 'gpt-4-32k': { | |||
| 'prompt': decimal.Decimal('0.06'), | |||
| 'completion': decimal.Decimal('0.12') | |||
| }, | |||
| 'gpt-3.5-turbo': { | |||
| 'prompt': decimal.Decimal('0.0015'), | |||
| 'completion': decimal.Decimal('0.002') | |||
| }, | |||
| 'gpt-3.5-turbo-16k': { | |||
| 'prompt': decimal.Decimal('0.003'), | |||
| 'completion': decimal.Decimal('0.004') | |||
| }, | |||
| 'text-davinci-003': { | |||
| 'prompt': decimal.Decimal('0.02'), | |||
| 'completion': decimal.Decimal('0.02') | |||
| }, | |||
| } | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| unit_price = model_unit_prices[self.name]['prompt'] | |||
| else: | |||
| unit_price = model_unit_prices[self.name]['completion'] | |||
| tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), | |||
| rounding=decimal.ROUND_HALF_UP) | |||
| total_price = tokens_per_1k * unit_price | |||
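| # Worked example (illustrative): 2,000 completion tokens on gpt-4 -> | |||
| # Decimal('2.000') * Decimal('0.06') = 0.12 USD before the final quantize below. | |||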
| return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| if self.name in COMPLETION_MODELS: | |||
| for k, v in provider_model_kwargs.items(): | |||
| if hasattr(self.client, k): | |||
| setattr(self.client, k, v) | |||
| else: | |||
| extra_model_kwargs = { | |||
| 'top_p': provider_model_kwargs.get('top_p'), | |||
| 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), | |||
| 'presence_penalty': provider_model_kwargs.get('presence_penalty'), | |||
| } | |||
| self.client.temperature = provider_model_kwargs.get('temperature') | |||
| self.client.max_tokens = provider_model_kwargs.get('max_tokens') | |||
| self.client.model_kwargs = extra_model_kwargs | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, openai.error.InvalidRequestError): | |||
| logging.warning("Invalid request to OpenAI API.") | |||
| return LLMBadRequestError(str(ex)) | |||
| elif isinstance(ex, openai.error.APIConnectionError): | |||
| logging.warning("Failed to connect to OpenAI API.") | |||
| return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): | |||
| logging.warning("OpenAI service unavailable.") | |||
| return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) | |||
| elif isinstance(ex, openai.error.RateLimitError): | |||
| return LLMRateLimitError(str(ex)) | |||
| elif isinstance(ex, openai.error.AuthenticationError): | |||
| return LLMAuthorizationError(str(ex)) | |||
| elif isinstance(ex, openai.error.OpenAIError): | |||
| return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) | |||
| else: | |||
| return ex | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return True | |||
| # def is_model_valid_or_raise(self): | |||
| # """ | |||
| # check is a valid model. | |||
| # | |||
| # :return: | |||
| # """ | |||
| # credentials = self._model_provider.get_credentials() | |||
| # | |||
| # try: | |||
| # result = openai.Model.retrieve( | |||
| # id=self.name, | |||
| # api_key=credentials.get('openai_api_key'), | |||
| # request_timeout=60 | |||
| # ) | |||
| # | |||
| # if 'id' not in result or result['id'] != self.name: | |||
| # raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.") | |||
| # except openai.error.OpenAIError as e: | |||
| # raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}") | |||
| # except Exception as e: | |||
| # logging.exception("OpenAI Model retrieve failed.") | |||
| # raise e | |||
| @@ -0,0 +1,103 @@ | |||
| import decimal | |||
| from functools import wraps | |||
| from typing import List, Optional, Any | |||
| from langchain.callbacks.manager import Callbacks | |||
| from langchain.schema import LLMResult, get_buffer_string | |||
| from replicate.exceptions import ReplicateError, ModelError | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.model_providers.error import LLMBadRequestError | |||
| from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.model_providers.models.entity.message import PromptMessage, MessageType | |||
| from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs | |||
| class ReplicateModel(BaseLLM): | |||
| def __init__(self, model_provider: BaseModelProvider, | |||
| name: str, | |||
| model_kwargs: ModelKwargs, | |||
| streaming: bool = False, | |||
| callbacks: Callbacks = None): | |||
| self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION | |||
| super().__init__(model_provider, name, model_kwargs, streaming, callbacks) | |||
| def _init_client(self) -> Any: | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) | |||
| return EnhanceReplicate( | |||
| model=self.name + ':' + self.credentials.get('model_version'), | |||
| input=provider_model_kwargs, | |||
| streaming=self.streaming, | |||
| replicate_api_token=self.credentials.get('replicate_api_token'), | |||
| callbacks=self.callbacks, | |||
| ) | |||
| def _run(self, messages: List[PromptMessage], | |||
| stop: Optional[List[str]] = None, | |||
| callbacks: Callbacks = None, | |||
| **kwargs) -> LLMResult: | |||
| """ | |||
| run prediction using the given prompt messages and stop words. | |||
| :param messages: | |||
| :param stop: | |||
| :param callbacks: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| extra_kwargs = {} | |||
| if isinstance(prompts, list): | |||
| system_messages = [message for message in messages if message.type == MessageType.SYSTEM] | |||
| if system_messages: | |||
| system_message = system_messages[0] | |||
| extra_kwargs['system_prompt'] = system_message.content | |||
| prompts = [message for message in messages if message.type != MessageType.SYSTEM] | |||
| prompts = get_buffer_string(prompts) | |||
| # The maximum length the generated tokens can have. | |||
| # Corresponds to the length of the input prompt + max_new_tokens. | |||
| if 'max_length' in self._client.input: | |||
| self._client.input['max_length'] = min( | |||
| self._client.input['max_length'] + self.get_num_tokens(messages), | |||
| self.model_rules.max_tokens.max | |||
| ) | |||
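| # Illustrative effect: with a configured max_length of 500 and a 200-token prompt, the | |||
| # request allows up to 700 prompt-plus-generated tokens, capped at model_rules.max_tokens.max. | |||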
| return self._client.generate([prompts], stop, callbacks, **extra_kwargs) | |||
| def get_num_tokens(self, messages: List[PromptMessage]) -> int: | |||
| """ | |||
| get num tokens of prompt messages. | |||
| :param messages: | |||
| :return: | |||
| """ | |||
| prompts = self._get_prompt_from_messages(messages) | |||
| if isinstance(prompts, list): | |||
| prompts = get_buffer_string(prompts) | |||
| return self._client.get_num_tokens(prompts) | |||
| def get_token_price(self, tokens: int, message_type: MessageType): | |||
| # Replicate charges by prediction time (seconds), not by token count | |||
| return decimal.Decimal('0') | |||
| def get_currency(self): | |||
| return 'USD' | |||
| def _set_model_kwargs(self, model_kwargs: ModelKwargs): | |||
| provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) | |||
| self.client.input = provider_model_kwargs | |||
| def handle_exceptions(self, ex: Exception) -> Exception: | |||
| if isinstance(ex, (ModelError, ReplicateError)): | |||
| return LLMBadRequestError(f"Replicate: {str(ex)}") | |||
| else: | |||
| return ex | |||
| @classmethod | |||
| def support_streaming(cls): | |||
| return True | |||