@@ -18,7 +18,8 @@ from models.model import Account
 import secrets
 import base64

-from models.provider import Provider
+from models.provider import Provider, ProviderName
+from services.provider_service import ProviderService


 @click.command('reset-password', help='Reset the account password.')
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
     click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))


+@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
+def sync_anthropic_hosted_providers():
+    click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
+    count = 0
+
+    page = 1
+    while True:
+        try:
+            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for tenant in tenants:
+            try:
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
+                ProviderService.create_system_provider(
+                    tenant,
+                    ProviderName.ANTHROPIC.value,
+                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+                    True
+                )
+                count += 1
+            except Exception as e:
+                click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
     app.cli.add_command(recreate_all_dataset_indexes)
+    app.cli.add_command(sync_anthropic_hosted_providers)
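
Note: a quick way to exercise the new command without a deployment is the Flask CLI test runner. This is a hypothetical sketch, not part of the diff; it assumes the project exposes an app factory (adjust the import to the real entrypoint):

    from app import create_app  # assumed entrypoint, adjust to the project's factory

    app = create_app()
    runner = app.test_cli_runner()
    result = runner.invoke(args=['sync-anthropic-hosted-providers'])
    print(result.output)  # expect "Start sync ..." then "Congratulations! Synced N ..."
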
@@ -51,6 +51,8 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
     'DEFAULT_LLM_PROVIDER': 'openai',
+    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
+    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
     'TENANT_DOCUMENT_COUNT': 100
 }
@@ -192,6 +194,10 @@ class Config:
         # hosted provider credentials
         self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
+        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
+        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
+        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')

         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
 class ProviderQuotaExceededError(BaseHTTPException):
     error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
+    description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
                   "Please go to Settings -> Model Provider to complete your own provider credentials."
     code = 400
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
                 account.current_tenant_id,
                 args['prompt_template']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
                 args['audiences'],
                 args['hoping_to_solve']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
         except MoreLikeThisDisabledError:
             yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
             raise NotFound("Message not found")
         except ConversationNotExistsError:
             raise NotFound("Conversation not found")
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
                 document_data=args,
                 account=current_user
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
         except MoreLikeThisDisabledError:
             yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -3,6 +3,7 @@ import base64
 import json
 import logging

+from flask import current_app
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, abort
 from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
         plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
         """
-        ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
+        ProviderService.init_supported_provider(current_user.current_tenant)
         providers = Provider.query.filter_by(tenant_id=tenant_id).all()

         provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
                     'quota_used': p.quota_used
                 } if p.provider_type == ProviderType.SYSTEM.value else {}),
                 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name))
+                                                                ProviderName(p.provider_name), only_custom=True)
                 if p.provider_type == ProviderType.CUSTOM.value else None
             }
             for p in providers
         ]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)

-        if provider_model.is_valid:
+        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
             other_providers = db.session.query(Provider).filter(
                 Provider.tenant_id == tenant.id,
+                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
                 Provider.provider_name != provider,
                 Provider.provider_type == ProviderType.CUSTOM.value
             ).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
         db.session.commit()

-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
         args = parser.parse_args()

         # todo: remove this when the provider is supported
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
             provider_model.is_valid = args['is_enabled']
             db.session.commit()
         elif not provider_model:
-            ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
+            if provider == ProviderName.OPENAI.value:
+                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
+            elif provider == ProviderName.ANTHROPIC.value:
+                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
+            else:
+                quota_limit = 0
+
+            ProviderService.create_system_provider(
+                tenant,
+                provider,
+                quota_limit,
+                args['is_enabled']
+            )
         else:
             abort(403)
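
Note: the quota selection above only grants hosted quotas to the two providers that have hosted credentials; everything else falls through to zero. A minimal sketch of the same logic, using the DEFAULTS added in config.py:

    # Hosted quota defaults from config.py (OPENAI: 200, ANTHROPIC: 1000).
    HOSTED_QUOTA_LIMITS = {
        'openai': 200,
        'anthropic': 1000,
    }

    def resolve_quota_limit(provider: str) -> int:
        # Providers without a hosted quota default to 0 (no hosted usage).
        return HOSTED_QUOTA_LIMITS.get(provider, 0)

    assert resolve_quota_limit('anthropic') == 1000
    assert resolve_quota_limit('cohere') == 0
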
@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
                 dataset_process_rule=dataset.latest_process_rule,
                 created_from='api'
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
         except MoreLikeThisDisabledError:
             yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-        except ProviderTokenNotInitError:
-            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+        except ProviderTokenNotInitError as ex:
+            yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
         except QuotaExceededError:
             yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
         except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
     api_key: str


+class HostedAnthropicCredential(BaseModel):
+    api_key: str
+
+
 class HostedLLMCredentials(BaseModel):
     openai: Optional[HostedOpenAICredential] = None
+    anthropic: Optional[HostedAnthropicCredential] = None


 hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
+
+    if app.config.get("ANTHROPIC_API_KEY"):
+        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })

         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])

     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -118,6 +119,7 @@ class Completion:
         prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode=mode
         )
@@ -138,7 +140,8 @@ class Completion:
         return response

     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+                            pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following CONTEXT as your learned knowledge:
-[CONTEXT]
+                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>

 When answer to user:
 - If you don't know, just say that you don't know.
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
-[CONTEXT]
+                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>

 When answer to user:
 - If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
             if pre_prompt:
                 human_message_prompt += pre_prompt

-            query_prompt = "\nHuman: {{query}}\nAI: "
+            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "

             if memory:
                 # append chat histories
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
                     inputs=human_inputs
                 )
-                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
-                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
-                              - memory.llm.max_tokens - curr_message_tokens
+                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+                model_name = model['name']
+                max_tokens = model.get("completion_params").get('max_tokens')
+                rest_tokens = llm_constant.max_context_token_length[model_name] \
+                              - max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
                 # if histories_param not in human_inputs:
                 #     human_inputs[histories_param] = '{{' + histories_param + '}}'
-                human_message_prompt += "\n\n" + histories
+                human_message_prompt += "\n\n" if human_message_prompt else ""
+                human_message_prompt += "Here is the chat histories between human and assistant, " \
+                                        "inside <histories></histories> XML tags.\n\n<histories>"
+                human_message_prompt += histories + "</histories>"

             human_message_prompt += query_prompt
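
Note: roughly what the assembled chat prompt looks like after this change, with context and histories wrapped in XML-style tags and the Anthropic-style "\n\nHuman:"/"\n\nAssistant:" turn markers (illustrative only; {{context}} and {{query}} are filled by JinjaPromptTemplate):

    prompt = (
        "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n"
        "<context>\n{{context}}\n</context>\n\n"
        "Here is the chat histories between human and assistant, "
        "inside <histories></histories> XML tags.\n\n"
        "<histories>\n\nHuman: hi\n\nAssistant: hello</histories>"
        "\n\nHuman: {{query}}\n\nAssistant: "
    )
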
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
             model=app_model_config.model_dict
         )

-        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
-        max_tokens = llm.max_tokens
+        model_name = app_model_config.model_dict.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')

         # get prompt without memory and context
         prompt, _ = cls.get_main_llm_prompt(
             mode=mode,
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
         return rest_tokens

     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                               prompt: Union[str, List[BaseMessage]], mode: str):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
-        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
-        max_tokens = final_llm.max_tokens
+        model_name = model.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = model.get("completion_params").get('max_tokens')

         if mode == 'completion' and isinstance(final_llm, BaseLLM):
             prompt_tokens = final_llm.get_num_tokens(prompt)
         else:
-            prompt_tokens = final_llm.get_messages_tokens(prompt)
+            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)

         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
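
Note: a worked example of the clamp above. When the prompt nearly fills the context window, max_tokens shrinks to whatever room is left, with a floor of 16 completion tokens:

    model_limited_tokens = 4096   # e.g. gpt-3.5-turbo
    prompt_tokens = 3900
    max_tokens = 512

    if prompt_tokens + max_tokens > model_limited_tokens:
        max_tokens = max(model_limited_tokens - prompt_tokens, 16)

    assert max_tokens == 196  # 4096 - 3900
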
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+        llm = LLMBuilder.to_llm_from_model(
             tenant_id=app.tenant_id,
-            model_name='gpt-3.5-turbo',
+            model=app_model_config.model_dict,
             streaming=streaming
         )
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
         original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
             inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
         cls.recale_llm_max_tokens(
             final_llm=llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode='completion'
         )
@@ -1,6 +1,8 @@
 from _decimal import Decimal

 models = {
+    'claude-instant-1': 'anthropic',  # 100,000 tokens
+    'claude-2': 'anthropic',  # 100,000 tokens
     'gpt-4': 'openai',  # 8,192 tokens
     'gpt-4-32k': 'openai',  # 32,768 tokens
     'gpt-3.5-turbo': 'openai',  # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
     'text-curie-001': 'openai',  # 2,049 tokens
     'text-babbage-001': 'openai',  # 2,049 tokens
     'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai'  # 8191 tokens, 1536 dimensions
+    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
+    'whisper-1': 'openai'
 }

 max_context_token_length = {
+    'claude-instant-1': 100000,
+    'claude-2': 100000,
     'gpt-4': 8192,
     'gpt-4-32k': 32768,
     'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
     'text-curie-001': 2049,
     'text-babbage-001': 2049,
     'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191
+    'text-embedding-ada-002': 8191,
 }

 models_by_mode = {
     'chat': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
         'gpt-3.5-turbo-16k',  # 16,384 tokens
     ],
     'completion': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
 model_currency = 'USD'

 model_prices = {
+    'claude-instant-1': {
+        'prompt': Decimal('0.00163'),
+        'completion': Decimal('0.00551'),
+    },
+    'claude-2': {
+        'prompt': Decimal('0.01102'),
+        'completion': Decimal('0.03268'),
+    },
     'gpt-4': {
         'prompt': Decimal('0.03'),
         'completion': Decimal('0.06'),
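
Note: the new price entries follow the existing per-1,000-token convention, so a message cost works out as below (sketch; the actual billing code path lives elsewhere):

    from decimal import Decimal

    prompt_price = Decimal('0.00163')      # claude-instant-1, per 1K prompt tokens
    completion_price = Decimal('0.00551')  # per 1K completion tokens

    # 1,200 prompt tokens and 300 completion tokens:
    cost = (Decimal(1200) / 1000) * prompt_price + (Decimal(300) / 1000) * completion_price
    print(cost)  # 0.003609 (USD)
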
@@ -56,7 +56,7 @@ class ConversationMessageTask:
         )

     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
+        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
         self.model_dict['provider'] = provider_name

         override_model_configs = None
@@ -89,7 +89,7 @@ class ConversationMessageTask:
             system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
             system_instruction = system_message.content
             llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-            system_instruction_tokens = llm.get_messages_tokens([system_message])
+            system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])

         if not self.conversation:
             self.is_new_conversation = True
@@ -185,6 +185,7 @@ class ConversationMessageTask:
         if provider and provider.provider_type == ProviderType.SYSTEM.value:
             db.session.query(Provider).filter(
                 Provider.tenant_id == self.app.tenant_id,
+                Provider.provider_name == provider.provider_name,
                 Provider.quota_limit > Provider.quota_used
             ).update({'quota_used': Provider.quota_used + 1})
@@ -4,6 +4,7 @@ from typing import List
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError

+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
         text_embeddings.extend(embedding_results)
         return text_embeddings

+    @handle_openai_exceptions
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
@@ -23,6 +23,10 @@ class LLMGenerator:
     @classmethod
     def generate_conversation_name(cls, tenant_id: str, query, answer):
         prompt = CONVERSATION_TITLE_PROMPT
+
+        if len(query) > 2000:
+            query = query[:300] + "...[TRUNCATED]..." + query[-300:]
+
         prompt = prompt.format(query=query)

         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
@@ -52,7 +56,17 @@ class LLMGenerator:
             if not message.answer:
                 continue

-            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
+            if len(message.query) > 2000:
+                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
+            else:
+                query = message.query
+
+            if len(message.answer) > 2000:
+                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
+            else:
+                answer = message.answer
+
+            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer

             if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                 context += message_qa_text
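
Note: the head/tail truncation repeated above could be read as a single helper. This is a hypothetical refactor for clarity, not code from the diff:

    def truncate_middle(text: str, limit: int = 2000, keep: int = 300) -> str:
        # Keep the first and last `keep` characters of over-long text,
        # marking the removed middle section.
        if len(text) <= limit:
            return text
        return text[:keep] + "...[TRUNCATED]..." + text[-keep:]

    assert len(truncate_middle("x" * 5000)) == 300 + len("...[TRUNCATED]...") + 300
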
@@ -17,7 +17,7 @@ class IndexBuilder:
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
     """
     description = "Provider Token Not Init"

+    def __init__(self, *args, **kwargs):
+        self.description = args[0] if args else self.description
+

 class QuotaExceededError(Exception):
     """
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
 from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
+from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType
+from models.provider import ProviderType, ProviderName


 class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
     @classmethod
     def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id)
+        provider = cls.get_default_provider(tenant_id, model_name)

         model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

+        llm_cls = None
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableChatOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureChatOpenAI
+            elif provider == ProviderName.ANTHROPIC.value:
+                llm_cls = StreamableChatAnthropic
         elif mode == 'completion':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureOpenAI
-        else:
-            raise ValueError(f"model name {model_name} is not supported.")
+
+        if not llm_cls:
+            raise ValueError(f"model name {model_name} is not supported.")

+        model_kwargs = {
+            'model_name': model_name,
+            'temperature': kwargs.get('temperature', 0),
+            'max_tokens': kwargs.get('max_tokens', 256),
+            'top_p': kwargs.get('top_p', 1),
+            'frequency_penalty': kwargs.get('frequency_penalty', 0),
+            'presence_penalty': kwargs.get('presence_penalty', 0),
+            'callbacks': kwargs.get('callbacks', None),
+            'streaming': kwargs.get('streaming', False),
+        }
-        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+        model_kwargs.update(model_credentials)
+        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)

-        return llm_cls(
-            model_name=model_name,
-            temperature=kwargs.get('temperature', 0),
-            max_tokens=kwargs.get('max_tokens', 256),
-            **model_extras_kwargs,
-            callbacks=kwargs.get('callbacks', None),
-            streaming=kwargs.get('streaming', False),
-            # request_timeout=None
-            **model_credentials
-        )
+        return llm_cls(**model_kwargs)

     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
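
Note: to summarize the new construction flow, one flat kwargs dict is built, credentials are merged in, and each LLM class reshapes the dict through its own get_kwargs_from_model_params before instantiation. A rough sketch (the Anthropic variant of get_kwargs_from_model_params is assumed to live in streamable_chat_anthropic.py, which this diff does not show):

    model_kwargs = {
        'model_name': 'claude-2',
        'temperature': 0,
        'max_tokens': 256,
        'top_p': 1,
        'frequency_penalty': 0,
        'presence_penalty': 0,
        'callbacks': None,
        'streaming': False,
    }
    model_kwargs.update({'anthropic_api_key': 'sk-ant-...'})  # credentials merged last
    # model_kwargs = StreamableChatAnthropic.get_kwargs_from_model_params(model_kwargs)
    # llm = StreamableChatAnthropic(**model_kwargs)
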
@@ -118,14 +119,29 @@ class LLMBuilder:
         return provider_service.get_credentials(model_name)

     @classmethod
-    def get_default_provider(cls, tenant_id: str) -> str:
-        provider = BaseProvider.get_valid_provider(tenant_id)
-        if not provider:
-            raise ProviderTokenNotInitError()
-
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            provider_name = 'openai'
-        else:
-            provider_name = provider.provider_name
+    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
+        provider_name = llm_constant.models[model_name]
+
+        if provider_name == 'openai':
+            # get the default provider (openai / azure_openai) for the tenant
+            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
+            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
+
+            provider = None
+            if openai_provider:
+                provider = openai_provider
+            elif azure_openai_provider:
+                provider = azure_openai_provider
+
+            if not provider:
+                raise ProviderTokenNotInitError(
+                    f"No valid {provider_name} model provider credentials found. "
+                    f"Please go to Settings -> Model Provider to complete your provider credentials."
+                )
+
+            if provider.provider_type == ProviderType.SYSTEM.value:
+                provider_name = 'openai'
+            else:
+                provider_name = provider.provider_name

         return provider_name
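
Note: a hedged sketch of the resolution order implemented above. The model family decides the provider; for OpenAI-family models a tenant's OpenAI credentials win over Azure OpenAI, while Anthropic models resolve directly (the real code raises ProviderTokenNotInitError, simplified to RuntimeError here):

    models = {'claude-2': 'anthropic', 'gpt-3.5-turbo': 'openai'}

    def pick(model_name: str, has_openai: bool, has_azure: bool) -> str:
        provider_name = models[model_name]
        if provider_name == 'openai':
            if not (has_openai or has_azure):
                raise RuntimeError('No valid openai model provider credentials found.')
            return 'openai' if has_openai else 'azure_openai'
        return provider_name

    assert pick('gpt-3.5-turbo', has_openai=False, has_azure=True) == 'azure_openai'
    assert pick('claude-2', has_openai=False, has_azure=False) == 'anthropic'
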
@@ -1,23 +1,138 @@
-from typing import Optional
+import json
+import logging
+from typing import Optional, Union

+import anthropic
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import HumanMessage
+
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
+from core.llm.provider.errors import ValidateFailedError
+from models.provider import ProviderName, ProviderType


 class AnthropicProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
+        return [
+            {
+                'id': 'claude-instant-1',
+                'name': 'claude-instant-1',
+            },
+            {
+                'id': 'claude-2',
+                'name': 'claude-2',
+            },
+        ]

     def get_credentials(self, model_id: Optional[str] = None) -> dict:
+        return self.get_provider_api_key(model_id=model_id)
+
+    def get_provider_name(self):
+        return ProviderName.ANTHROPIC
+
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
         """
-        Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
-        The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
+        Returns the provider configs.
         """
-        return {
-            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
-        }
+        try:
+            config = self.get_provider_api_key(only_custom=only_custom)
+        except:
+            config = {
+                'anthropic_api_key': ''
+            }

-    def get_provider_name(self):
-        return ProviderName.ANTHROPIC
+        if obfuscated:
+            if not config.get('anthropic_api_key'):
+                config = {
+                    'anthropic_api_key': ''
+                }
+
+            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
+            return config
+
+        return config
+
+    def get_encrypted_token(self, config: Union[dict | str]):
+        """
+        Returns the encrypted token.
+        """
+        return json.dumps({
+            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
+        })
+
+    def get_decrypted_token(self, token: str):
+        """
+        Returns the decrypted token.
+        """
+        config = json.loads(token)
+        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
+        return config
+
+    def get_token_type(self):
+        return dict
+
+    def config_validate(self, config: Union[dict | str]):
+        """
+        Validates the given config.
+        """
+        # check OpenAI / Azure OpenAI credential is valid
+        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
+        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
+
+        provider = None
+        if openai_provider:
+            provider = openai_provider
+        elif azure_openai_provider:
+            provider = azure_openai_provider
+
+        if not provider:
+            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            quota_used = provider.quota_used if provider.quota_used is not None else 0
+            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
+            if quota_used >= quota_limit:
+                raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
+                                          f"please configure OpenAI or Azure OpenAI provider first.")
+
+        try:
+            if not isinstance(config, dict):
+                raise ValueError('Config must be a object.')
+
+            if 'anthropic_api_key' not in config:
+                raise ValueError('anthropic_api_key must be provided.')
+
+            chat_llm = ChatAnthropic(
+                model='claude-instant-1',
+                anthropic_api_key=config['anthropic_api_key'],
+                max_tokens_to_sample=10,
+                temperature=0,
+                default_request_timeout=60
+            )
+
+            messages = [
+                HumanMessage(
+                    content="ping"
+                )
+            ]
+
+            chat_llm(messages)
+        except anthropic.APIConnectionError as ex:
+            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
+        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
+            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
+                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
+        except Exception as ex:
+            logging.exception('Anthropic config validation failed')
+            raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
+
+        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
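
Note: hypothetical usage of the validator above (module path assumed from the import style in this diff); a bad key should surface as ValidateFailedError rather than as an unhandled anthropic exception:

    from core.llm.provider.anthropic_provider import AnthropicProvider  # assumed path
    from core.llm.provider.errors import ValidateFailedError

    provider = AnthropicProvider(tenant_id='tenant-id')
    try:
        provider.config_validate({'anthropic_api_key': 'invalid-key'})
    except ValidateFailedError as e:
        print(f"validation failed: {e}")
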
| @@ -52,12 +52,12 @@ class AzureProvider(BaseProvider): | |||
| def get_provider_name(self): | |||
| return ProviderName.AZURE_OPENAI | |||
| def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the provider configs. | |||
| """ | |||
| try: | |||
| config = self.get_provider_api_key() | |||
| config = self.get_provider_api_key(only_custom=only_custom) | |||
| except: | |||
| config = { | |||
| 'openai_api_type': 'azure', | |||
| @@ -81,7 +81,6 @@ class AzureProvider(BaseProvider): | |||
| return config | |||
| def get_token_type(self): | |||
| # TODO: change to dict when implemented | |||
| return dict | |||
| def config_validate(self, config: Union[dict | str]): | |||
| @@ -2,7 +2,7 @@ import base64 | |||
| from abc import ABC, abstractmethod | |||
| from typing import Optional, Union | |||
| from core import hosted_llm_credentials | |||
| from core.constant import llm_constant | |||
| from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from libs import rsa | |||
| @@ -14,15 +14,18 @@ class BaseProvider(ABC): | |||
| def __init__(self, tenant_id: str): | |||
| self.tenant_id = tenant_id | |||
| def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]: | |||
| def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the decrypted API key for the given tenant_id and provider_name. | |||
| If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. | |||
| If the provider is not found or not valid, raises a ProviderTokenNotInitError. | |||
| """ | |||
| provider = self.get_provider(prefer_custom) | |||
| provider = self.get_provider(only_custom) | |||
| if not provider: | |||
| raise ProviderTokenNotInitError() | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {llm_constant.models[model_id]} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||
| quota_used = provider.quota_used if provider.quota_used is not None else 0 | |||
| @@ -38,18 +41,19 @@ class BaseProvider(ABC): | |||
| else: | |||
| return self.get_decrypted_token(provider.encrypted_config) | |||
| def get_provider(self, prefer_custom: bool) -> Optional[Provider]: | |||
| def get_provider(self, only_custom: bool = False) -> Optional[Provider]: | |||
| """ | |||
| Returns the Provider instance for the given tenant_id and provider_name. | |||
| If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. | |||
| """ | |||
| return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) | |||
| return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom) | |||
| @classmethod | |||
| def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: | |||
| def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[ | |||
| Provider]: | |||
| """ | |||
| Returns the Provider instance for the given tenant_id and provider_name. | |||
| If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. | |||
| If both CUSTOM and System providers exist. | |||
| """ | |||
| query = db.session.query(Provider).filter( | |||
| Provider.tenant_id == tenant_id | |||
| @@ -58,39 +62,31 @@ class BaseProvider(ABC): | |||
| if provider_name: | |||
| query = query.filter(Provider.provider_name == provider_name) | |||
| providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() | |||
| if only_custom: | |||
| query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value) | |||
| custom_provider = None | |||
| system_provider = None | |||
| providers = query.order_by(Provider.provider_type.asc()).all() | |||
| for provider in providers: | |||
| if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: | |||
| custom_provider = provider | |||
| return provider | |||
| elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: | |||
| system_provider = provider | |||
| if custom_provider: | |||
| return custom_provider | |||
| elif system_provider: | |||
| return system_provider | |||
| else: | |||
| return None | |||
| return provider | |||
| def get_hosted_credentials(self) -> str: | |||
| if self.get_provider_name() != ProviderName.OPENAI: | |||
| raise ProviderTokenNotInitError() | |||
| return None | |||
| if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: | |||
| raise ProviderTokenNotInitError() | |||
| return hosted_llm_credentials.openai.api_key | |||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| """ | |||
| Returns the provider configs. | |||
| """ | |||
| try: | |||
| config = self.get_provider_api_key() | |||
| config = self.get_provider_api_key(only_custom=only_custom) | |||
| except: | |||
| config = '' | |||
| @@ -31,11 +31,11 @@ class LLMProviderService: | |||
| def get_credentials(self, model_id: Optional[str] = None) -> dict: | |||
| return self.provider.get_credentials(model_id) | |||
| def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: | |||
| return self.provider.get_provider_configs(obfuscated) | |||
| def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: | |||
| return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom) | |||
| def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]: | |||
| return self.provider.get_provider(prefer_custom) | |||
| def get_provider_db_record(self) -> Optional[Provider]: | |||
| return self.provider.get_provider() | |||
| def config_validate(self, config: Union[dict | str]): | |||
| """ | |||
| @@ -4,6 +4,8 @@ from typing import Optional, Union | |||
| import openai | |||
| from openai.error import AuthenticationError, OpenAIError | |||
| from core import hosted_llm_credentials | |||
| from core.llm.error import ProviderTokenNotInitError | |||
| from core.llm.moderation import Moderation | |||
| from core.llm.provider.base import BaseProvider | |||
| from core.llm.provider.errors import ValidateFailedError | |||
| @@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider): | |||
| except Exception as ex: | |||
| logging.exception('OpenAI config validation failed') | |||
| raise ex | |||
| def get_hosted_credentials(self) -> Union[str | dict]: | |||
| if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: | |||
| raise ProviderTokenNotInitError( | |||
| f"No valid {self.get_provider_name().value} model provider credentials found. " | |||
| f"Please go to Settings -> Model Provider to complete your provider credentials." | |||
| ) | |||
| return hosted_llm_credentials.openai.api_key | |||
@@ -1,11 +1,11 @@

from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableAzureChatOpenAI(AzureChatOpenAI):

@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # calc once
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens

    @handle_llm_exceptions
    @handle_openai_exceptions
    def generate(
        self,
        messages: List[List[BaseMessage]],
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
    ) -> LLMResult:
        return super().generate(messages, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        return await super().agenerate(messages, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        model_kwargs = {
            'top_p': params.get('top_p', 1),
            'frequency_penalty': params.get('frequency_penalty', 0),
            'presence_penalty': params.get('presence_penalty', 0),
        }

        del params['top_p']
        del params['frequency_penalty']
        del params['presence_penalty']

        params['model_kwargs'] = model_kwargs

        return params
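To make the parameter shuffling concrete, here is a minimal sketch of what get_kwargs_from_model_params does for the OpenAI-style chat models (the input values are hypothetical; StreamableChatOpenAI further down applies the identical mapping):

params = {
    'max_tokens': 512,
    'temperature': 0.7,
    'top_p': 1,
    'frequency_penalty': 0,
    'presence_penalty': 0,
}
params = StreamableAzureChatOpenAI.get_kwargs_from_model_params(params)
# Result:
# {'max_tokens': 512, 'temperature': 0.7,
#  'model_kwargs': {'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}}
# top_p and the penalties move under model_kwargs, presumably because the
# langchain chat model classes do not expose them as first-class fields.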
@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any

from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableAzureOpenAI(AzureOpenAI):

@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    @handle_openai_exceptions
    def generate(
        self,
        prompts: List[str],
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
    ) -> LLMResult:
        return super().generate(prompts, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        return await super().agenerate(prompts, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        return params
@@ -0,0 +1,39 @@

from typing import List, Optional, Any, Dict

from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult

from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions


class StreamableChatAnthropic(ChatAnthropic):
    """
    Wrapper around Anthropic's large language model.
    """

    @handle_anthropic_exceptions
    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        params['model'] = params.get('model_name')
        del params['model_name']

        params['max_tokens_to_sample'] = params.get('max_tokens')
        del params['max_tokens']

        del params['frequency_penalty']
        del params['presence_penalty']

        return params
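The Anthropic mapping differs from the OpenAI-style one in two ways: model_name and max_tokens are renamed to Anthropic's model and max_tokens_to_sample, and the penalty parameters are dropped outright because Anthropic's API does not offer them. A sketch with hypothetical values:

params = {
    'model_name': 'claude-2',
    'max_tokens': 256,
    'temperature': 0.7,
    'top_p': 1,
    'frequency_penalty': 0,
    'presence_penalty': 0,
}
params = StreamableChatAnthropic.get_kwargs_from_model_params(params)
# Result: {'model': 'claude-2', 'max_tokens_to_sample': 256,
#          'temperature': 0.7, 'top_p': 1}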
@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any

from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableChatOpenAI(ChatOpenAI):

@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # calc once
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens

    @handle_llm_exceptions
    @handle_openai_exceptions
    def generate(
        self,
        messages: List[List[BaseMessage]],
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
    ) -> LLMResult:
        return super().generate(messages, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        return await super().agenerate(messages, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        model_kwargs = {
            'top_p': params.get('top_p', 1),
            'frequency_penalty': params.get('frequency_penalty', 0),
            'presence_penalty': params.get('presence_penalty', 0),
        }

        del params['top_p']
        del params['frequency_penalty']
        del params['presence_penalty']

        params['model_kwargs'] = model_kwargs

        return params
@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping

from langchain import OpenAI
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableOpenAI(OpenAI):

@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    @handle_openai_exceptions
    def generate(
        self,
        prompts: List[str],
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
    ) -> LLMResult:
        return super().generate(prompts, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        return await super().agenerate(prompts, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        return params
@@ -1,6 +1,7 @@

import openai

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName
from core.llm.error_handle_wraps import handle_llm_exceptions
from core.llm.provider.base import BaseProvider

@@ -13,7 +14,7 @@ class Whisper:
        self.client = openai.Audio
        self.credentials = provider.get_credentials()

    @handle_llm_exceptions
    @handle_openai_exceptions
    def transcribe(self, file):
        return self.client.transcribe(
            model='whisper-1',
@@ -0,0 +1,27 @@

import logging
from functools import wraps

import anthropic

from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
    LLMBadRequestError


def handle_anthropic_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except anthropic.APIConnectionError as e:
            logging.exception("Failed to connect to Anthropic API.")
            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
        except anthropic.RateLimitError:
            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
        except anthropic.AuthenticationError as e:
            raise LLMAuthorizationError(f"Anthropic: {e.message}")
        except anthropic.BadRequestError as e:
            raise LLMBadRequestError(f"Anthropic: {e.message}")
        except anthropic.APIStatusError as e:
            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")

    return wrapper
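For reference, a self-contained sketch of the decorator in use; the ping function below is hypothetical, but the client calls match the SDK version pinned later in this diff (anthropic~=0.3.4). Any anthropic.* exception raised inside the call is translated into the app's LLM* error hierarchy:

import anthropic

@handle_anthropic_exceptions
def ping(client: anthropic.Anthropic) -> str:
    # A bad key raises anthropic.AuthenticationError here, which the wrapper
    # re-raises as LLMAuthorizationError.
    completion = client.completions.create(
        model='claude-instant-1',
        max_tokens_to_sample=16,
        prompt=f"{anthropic.HUMAN_PROMPT} ping{anthropic.AI_PROMPT}",
    )
    return completion.completion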
@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
    LLMBadRequestError


def handle_llm_exceptions(func):
def handle_openai_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:

@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))

    return wrapper


def handle_llm_exceptions_async(func):
    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except openai.error.InvalidRequestError as e:
            logging.exception("Invalid request to OpenAI API.")
            raise LLMBadRequestError(str(e))
        except openai.error.APIConnectionError as e:
            logging.exception("Failed to connect to OpenAI API.")
            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
            logging.exception("OpenAI service unavailable.")
            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
        except openai.error.RateLimitError as e:
            raise LLMRateLimitError(str(e))
        except openai.error.AuthenticationError as e:
            raise LLMAuthorizationError(str(e))
        except openai.error.OpenAIError as e:
            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))

    return wrapper
@@ -1,7 +1,7 @@

from typing import Any, List, Dict, Union

from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel

from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI

@@ -12,8 +12,8 @@ from models.model import Conversation, Message

class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
    conversation: Conversation
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    llm: Union[StreamableChatOpenAI | StreamableOpenAI]
    ai_prefix: str = "Assistant"
    llm: BaseLanguageModel
    memory_key: str = "chat_history"
    max_token_limit: int = 2000
    message_limit: int = 10

@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
            return chat_messages

        # prune the chat message if it exceeds the max token limit
        curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
        if curr_buffer_length > self.max_token_limit:
            pruned_memory = []
            while curr_buffer_length > self.max_token_limit and chat_messages:
                pruned_memory.append(chat_messages.pop(0))
                curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)

        return chat_messages
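The swap from the hand-rolled get_messages_tokens to langchain's built-in get_num_tokens_from_messages leaves the pruning strategy unchanged: drop the oldest messages until the buffer fits. A standalone sketch of that strategy under assumed setup (token counting runs locally via tiktoken, so the placeholder key never hits the API):

from langchain.chat_models import ChatOpenAI
from langchain.schema import AIMessage, HumanMessage

llm = ChatOpenAI(openai_api_key='sk-placeholder')  # placeholder; counting only
messages = [HumanMessage(content='hi'), AIMessage(content='hello!')]
max_token_limit = 2000

while llm.get_num_tokens_from_messages(messages) > max_token_limit and messages:
    messages.pop(0)  # evict oldest first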
@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
        else:
            model_credentials = LLMBuilder.get_model_credentials(
                tenant_id=self.dataset.tenant_id,
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                model_name='text-embedding-ada-002'
            )

@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
    async def _arun(self, tool_input: str) -> str:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=self.dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
            model_name='text-embedding-ada-002'
        )
@@ -1,4 +1,7 @@

from flask import current_app

from events.tenant_event import tenant_was_updated
from models.provider import ProviderName
from services.provider_service import ProviderService

@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
    tenant = sender
    if tenant.status == 'normal':
        ProviderService.create_system_provider(tenant)
        ProviderService.create_system_provider(
            tenant,
            ProviderName.OPENAI.value,
            current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
            True
        )
        ProviderService.create_system_provider(
            tenant,
            ProviderName.ANTHROPIC.value,
            current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
            True
        )
@@ -1,4 +1,7 @@

from flask import current_app

from events.tenant_event import tenant_was_created
from models.provider import ProviderName
from services.provider_service import ProviderService

@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs):
    tenant = sender
    if tenant.status == 'normal':
        ProviderService.create_system_provider(tenant)
        ProviderService.create_system_provider(
            tenant,
            ProviderName.OPENAI.value,
            current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
            True
        )
        ProviderService.create_system_provider(
            tenant,
            ProviderName.ANTHROPIC.value,
            current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
            True
        )
@@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10
gunicorn~=20.1.0
gevent~=22.10.2
langchain==0.0.209
langchain==0.0.230
openai~=0.27.5
psycopg2-binary~=2.9.6
pycryptodome==3.17

@@ -35,3 +35,4 @@ docx2txt==0.8
pypdfium2==4.16.0
resend~=0.5.1
pyjwt~=2.6.0
anthropic~=0.3.4
@@ -6,6 +6,30 @@ from models.account import Account

from services.dataset_service import DatasetService
from core.llm.llm_builder import LLMBuilder

MODEL_PROVIDERS = [
    'openai',
    'anthropic',
]

MODELS_BY_APP_MODE = {
    'chat': [
        'claude-instant-1',
        'claude-2',
        'gpt-4',
        'gpt-4-32k',
        'gpt-3.5-turbo',
        'gpt-3.5-turbo-16k',
    ],
    'completion': [
        'claude-instant-1',
        'claude-2',
        'gpt-4',
        'gpt-4-32k',
        'gpt-3.5-turbo',
        'gpt-3.5-turbo-16k',
        'text-davinci-003',
    ]
}


class AppModelConfigService:
    @staticmethod

@@ -125,7 +149,7 @@ class AppModelConfigService:
        if not isinstance(config["speech_to_text"]["enabled"], bool):
            raise ValueError("enabled in speech_to_text must be of boolean type")

        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id)
        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')

        if config["speech_to_text"]["enabled"] and provider_name != 'openai':
            raise ValueError("provider not support speech to text")

@@ -153,14 +177,14 @@
            raise ValueError("model must be of object type")

        # model.provider
        if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
            raise ValueError("model.provider must be 'openai'")
        if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
            raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")

        # model.name
        if 'name' not in config["model"]:
            raise ValueError("model.name is required")

        if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
        if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
            raise ValueError("model.name must be in the specified model list")

        # model.completion_params
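A compact illustration, with a hypothetical config fragment, of what the reworked checks now accept:

mode = 'chat'
model = {'provider': 'anthropic', 'name': 'claude-2'}

assert model['provider'] in MODEL_PROVIDERS        # 'openai' or 'anthropic'
assert model['name'] in MODELS_BY_APP_MODE[mode]   # claude-2 is allowed for chat apps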
@@ -27,7 +27,7 @@ class AudioService:
            message = f"Audio size larger than {FILE_SIZE} mb"
            raise AudioTooLargeServiceError(message)

        provider_name = LLMBuilder.get_default_provider(tenant_id)
        provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')

        if provider_name != ProviderName.OPENAI.value:
            raise ProviderNotSupportSpeechToTextServiceError()

@@ -37,8 +37,3 @@ class AudioService:
        buffer.name = 'temp.mp3'

        return Whisper(provider_service.provider).transcribe(buffer)
@@ -31,7 +31,7 @@ class HitTestingService:
        model_credentials = LLMBuilder.get_model_credentials(
            tenant_id=dataset.tenant_id,
            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
            model_name='text-embedding-ada-002'
        )
@@ -10,50 +10,40 @@ from models.provider import *

class ProviderService:

    @staticmethod
    def init_supported_provider(tenant, edition):
    def init_supported_provider(tenant):
        """Initialize the model provider, check whether the supported provider has a record"""
        providers = Provider.query.filter_by(tenant_id=tenant.id).all()
        need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]

        openai_provider_exists = False
        azure_openai_provider_exists = False

        # TODO: The cloud version needs to construct the data of the SYSTEM type
        providers = db.session.query(Provider).filter(
            Provider.tenant_id == tenant.id,
            Provider.provider_type == ProviderType.CUSTOM.value,
            Provider.provider_name.in_(need_init_provider_names)
        ).all()

        exists_provider_names = []
        for provider in providers:
            if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
                openai_provider_exists = True
            if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
                azure_openai_provider_exists = True
            exists_provider_names.append(provider.provider_name)

        # Initialize the model provider, check whether the supported provider has a record
        not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))

        # Create default providers if they don't exist
        if not openai_provider_exists:
            openai_provider = Provider(
                tenant_id=tenant.id,
                provider_name=ProviderName.OPENAI.value,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=False
            )
            db.session.add(openai_provider)

        if not azure_openai_provider_exists:
            azure_openai_provider = Provider(
                tenant_id=tenant.id,
                provider_name=ProviderName.AZURE_OPENAI.value,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=False
            )
            db.session.add(azure_openai_provider)

        if not_exists_provider_names:
            # Initialize the model provider, check whether the supported provider has a record
            for provider_name in not_exists_provider_names:
                provider = Provider(
                    tenant_id=tenant.id,
                    provider_name=provider_name,
                    provider_type=ProviderType.CUSTOM.value,
                    is_valid=False
                )
                db.session.add(provider)

        if not openai_provider_exists or not azure_openai_provider_exists:
            db.session.commit()

    @staticmethod
    def get_obfuscated_api_key(tenant, provider_name: ProviderName):
    def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
        llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
        return llm_provider_service.get_provider_configs(obfuscated=True)
        return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)

    @staticmethod
    def get_token_type(tenant, provider_name: ProviderName):

@@ -73,7 +63,7 @@ class ProviderService:
        return llm_provider_service.get_encrypted_token(configs)

    @staticmethod
    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
    def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
                               is_valid: bool = True):
        if current_app.config['EDITION'] != 'CLOUD':
            return

@@ -90,7 +80,7 @@
            provider_name=provider_name,
            provider_type=ProviderType.SYSTEM.value,
            quota_type=ProviderQuotaType.TRIAL.value,
            quota_limit=200,
            quota_limit=quota_limit,
            encrypted_config='',
            is_valid=is_valid,
        )
@@ -1,6 +1,6 @@

from extensions.ext_database import db
from models.account import Tenant
from models.provider import Provider, ProviderType
from models.provider import Provider, ProviderType, ProviderName


class WorkspaceService:

@@ -33,7 +33,7 @@ class WorkspaceService:
            if provider.is_valid and provider.encrypted_config:
                custom_provider = provider
            elif provider.provider_type == ProviderType.SYSTEM.value:
                if provider.is_valid:
                if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
                    system_provider = provider

        if system_provider and not custom_provider: