| 'PDF_PREVIEW': 'True', | 'PDF_PREVIEW': 'True', | ||||
| 'LOG_LEVEL': 'INFO', | 'LOG_LEVEL': 'INFO', | ||||
| 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', | 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', | ||||
| 'DEFAULT_LLM_PROVIDER': 'openai' | |||||
| } | } | ||||
| # You could disable it for compatibility with certain OpenAPI providers | # You could disable it for compatibility with certain OpenAPI providers | ||||
| self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION') | self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION') | ||||
| # For temp use only | |||||
| # set default LLM provider, default is 'openai', support `azure_openai` | |||||
| self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') | |||||
| class CloudEditionConfig(Config): | class CloudEditionConfig(Config): | ||||
| def __init__(self): | def __init__(self): |
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if not args['token']: | |||||
| raise ValueError('Token is empty') | |||||
| try: | |||||
| ProviderService.validate_provider_configs( | |||||
| if args['token']: | |||||
| try: | |||||
| ProviderService.validate_provider_configs( | |||||
| tenant=current_user.current_tenant, | |||||
| provider_name=ProviderName(provider), | |||||
| configs=args['token'] | |||||
| ) | |||||
| token_is_valid = True | |||||
| except ValidateFailedError: | |||||
| token_is_valid = False | |||||
| base64_encrypted_token = ProviderService.get_encrypted_token( | |||||
| tenant=current_user.current_tenant, | tenant=current_user.current_tenant, | ||||
| provider_name=ProviderName(provider), | provider_name=ProviderName(provider), | ||||
| configs=args['token'] | configs=args['token'] | ||||
| ) | ) | ||||
| token_is_valid = True | |||||
| except ValidateFailedError: | |||||
| else: | |||||
| base64_encrypted_token = None | |||||
| token_is_valid = False | token_is_valid = False | ||||
| tenant = current_user.current_tenant | tenant = current_user.current_tenant | ||||
| base64_encrypted_token = ProviderService.get_encrypted_token( | |||||
| tenant=current_user.current_tenant, | |||||
| provider_name=ProviderName(provider), | |||||
| configs=args['token'] | |||||
| ) | |||||
| provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider, | |||||
| provider_type=ProviderType.CUSTOM.value).first() | |||||
| provider_model = db.session.query(Provider).filter( | |||||
| Provider.tenant_id == tenant.id, | |||||
| Provider.provider_name == provider, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value | |||||
| ).first() | |||||
| # Only allow updating token for CUSTOM provider type | # Only allow updating token for CUSTOM provider type | ||||
| if provider_model: | if provider_model: | ||||
| is_valid=token_is_valid) | is_valid=token_is_valid) | ||||
| db.session.add(provider_model) | db.session.add(provider_model) | ||||
| if provider_model.is_valid: | |||||
| other_providers = db.session.query(Provider).filter( | |||||
| Provider.tenant_id == tenant.id, | |||||
| Provider.provider_name != provider, | |||||
| Provider.provider_type == ProviderType.CUSTOM.value | |||||
| ).all() | |||||
| for other_provider in other_providers: | |||||
| other_provider.is_valid = False | |||||
| db.session.commit() | db.session.commit() | ||||
| if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, | if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, |
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | ||||
| def get_embedding( | def get_embedding( | ||||
| text: str, | |||||
| engine: Optional[str] = None, | |||||
| openai_api_key: Optional[str] = None, | |||||
| text: str, | |||||
| engine: Optional[str] = None, | |||||
| api_key: Optional[str] = None, | |||||
| **kwargs | |||||
| ) -> List[float]: | ) -> List[float]: | ||||
| """Get embedding. | """Get embedding. | ||||
| """ | """ | ||||
| text = text.replace("\n", " ") | text = text.replace("\n", " ") | ||||
| return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"] | |||||
| return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"] | |||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | ||||
| async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]: | |||||
| async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[ | |||||
| float]: | |||||
| """Asynchronously get embedding. | """Asynchronously get embedding. | ||||
| NOTE: Copied from OpenAI's embedding utils: | NOTE: Copied from OpenAI's embedding utils: | ||||
| # replace newlines, which can negatively affect performance. | # replace newlines, which can negatively affect performance. | ||||
| text = text.replace("\n", " ") | text = text.replace("\n", " ") | ||||
| return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][ | |||||
| return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][ | |||||
| "embedding" | "embedding" | ||||
| ] | ] | ||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | ||||
| def get_embeddings( | def get_embeddings( | ||||
| list_of_text: List[str], | |||||
| engine: Optional[str] = None, | |||||
| openai_api_key: Optional[str] = None | |||||
| list_of_text: List[str], | |||||
| engine: Optional[str] = None, | |||||
| api_key: Optional[str] = None, | |||||
| **kwargs | |||||
| ) -> List[List[float]]: | ) -> List[List[float]]: | ||||
| """Get embeddings. | """Get embeddings. | ||||
| # replace newlines, which can negatively affect performance. | # replace newlines, which can negatively affect performance. | ||||
| list_of_text = [text.replace("\n", " ") for text in list_of_text] | list_of_text = [text.replace("\n", " ") for text in list_of_text] | ||||
| data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data | |||||
| data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data | |||||
| data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | ||||
| return [d["embedding"] for d in data] | return [d["embedding"] for d in data] | ||||
| @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) | ||||
| async def aget_embeddings( | async def aget_embeddings( | ||||
| list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None | |||||
| list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs | |||||
| ) -> List[List[float]]: | ) -> List[List[float]]: | ||||
| """Asynchronously get embeddings. | """Asynchronously get embeddings. | ||||
| # replace newlines, which can negatively affect performance. | # replace newlines, which can negatively affect performance. | ||||
| list_of_text = [text.replace("\n", " ") for text in list_of_text] | list_of_text = [text.replace("\n", " ") for text in list_of_text] | ||||
| data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data | |||||
| data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data | |||||
| data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. | ||||
| return [d["embedding"] for d in data] | return [d["embedding"] for d in data] | ||||
| class OpenAIEmbedding(BaseEmbedding): | class OpenAIEmbedding(BaseEmbedding): | ||||
| def __init__( | def __init__( | ||||
| self, | |||||
| mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, | |||||
| model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, | |||||
| deployment_name: Optional[str] = None, | |||||
| openai_api_key: Optional[str] = None, | |||||
| **kwargs: Any, | |||||
| self, | |||||
| mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, | |||||
| model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, | |||||
| deployment_name: Optional[str] = None, | |||||
| openai_api_key: Optional[str] = None, | |||||
| **kwargs: Any, | |||||
| ) -> None: | ) -> None: | ||||
| """Init params.""" | """Init params.""" | ||||
| super().__init__(**kwargs) | |||||
| new_kwargs = {} | |||||
| if 'embed_batch_size' in kwargs: | |||||
| new_kwargs['embed_batch_size'] = kwargs['embed_batch_size'] | |||||
| if 'tokenizer' in kwargs: | |||||
| new_kwargs['tokenizer'] = kwargs['tokenizer'] | |||||
| super().__init__(**new_kwargs) | |||||
| self.mode = OpenAIEmbeddingMode(mode) | self.mode = OpenAIEmbeddingMode(mode) | ||||
| self.model = OpenAIEmbeddingModelType(model) | self.model = OpenAIEmbeddingModelType(model) | ||||
| self.deployment_name = deployment_name | self.deployment_name = deployment_name | ||||
| self.openai_api_key = openai_api_key | self.openai_api_key = openai_api_key | ||||
| self.openai_api_type = kwargs.get('openai_api_type') | |||||
| self.openai_api_version = kwargs.get('openai_api_version') | |||||
| self.openai_api_base = kwargs.get('openai_api_base') | |||||
| @handle_llm_exceptions | @handle_llm_exceptions | ||||
| def _get_query_embedding(self, query: str) -> List[float]: | def _get_query_embedding(self, query: str) -> List[float]: | ||||
| if key not in _QUERY_MODE_MODEL_DICT: | if key not in _QUERY_MODE_MODEL_DICT: | ||||
| raise ValueError(f"Invalid mode, model combination: {key}") | raise ValueError(f"Invalid mode, model combination: {key}") | ||||
| engine = _QUERY_MODE_MODEL_DICT[key] | engine = _QUERY_MODE_MODEL_DICT[key] | ||||
| return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key) | |||||
| return get_embedding(query, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| def _get_text_embedding(self, text: str) -> List[float]: | def _get_text_embedding(self, text: str) -> List[float]: | ||||
| """Get text embedding.""" | """Get text embedding.""" | ||||
| if key not in _TEXT_MODE_MODEL_DICT: | if key not in _TEXT_MODE_MODEL_DICT: | ||||
| raise ValueError(f"Invalid mode, model combination: {key}") | raise ValueError(f"Invalid mode, model combination: {key}") | ||||
| engine = _TEXT_MODE_MODEL_DICT[key] | engine = _TEXT_MODE_MODEL_DICT[key] | ||||
| return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key) | |||||
| return get_embedding(text, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| async def _aget_text_embedding(self, text: str) -> List[float]: | async def _aget_text_embedding(self, text: str) -> List[float]: | ||||
| """Asynchronously get text embedding.""" | """Asynchronously get text embedding.""" | ||||
| if key not in _TEXT_MODE_MODEL_DICT: | if key not in _TEXT_MODE_MODEL_DICT: | ||||
| raise ValueError(f"Invalid mode, model combination: {key}") | raise ValueError(f"Invalid mode, model combination: {key}") | ||||
| engine = _TEXT_MODE_MODEL_DICT[key] | engine = _TEXT_MODE_MODEL_DICT[key] | ||||
| return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key) | |||||
| return await aget_embedding(text, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: | def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: | ||||
| """Get text embeddings. | """Get text embeddings. | ||||
| if key not in _TEXT_MODE_MODEL_DICT: | if key not in _TEXT_MODE_MODEL_DICT: | ||||
| raise ValueError(f"Invalid mode, model combination: {key}") | raise ValueError(f"Invalid mode, model combination: {key}") | ||||
| engine = _TEXT_MODE_MODEL_DICT[key] | engine = _TEXT_MODE_MODEL_DICT[key] | ||||
| embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key) | |||||
| embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| return embeddings | return embeddings | ||||
| async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: | async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: | ||||
| if key not in _TEXT_MODE_MODEL_DICT: | if key not in _TEXT_MODE_MODEL_DICT: | ||||
| raise ValueError(f"Invalid mode, model combination: {key}") | raise ValueError(f"Invalid mode, model combination: {key}") | ||||
| engine = _TEXT_MODE_MODEL_DICT[key] | engine = _TEXT_MODE_MODEL_DICT[key] | ||||
| embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key) | |||||
| embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key, | |||||
| api_type=self.openai_api_type, api_version=self.openai_api_version, | |||||
| api_base=self.openai_api_base) | |||||
| return embeddings | return embeddings |
| max_chunk_overlap=20 | max_chunk_overlap=20 | ||||
| ) | ) | ||||
| provider = LLMBuilder.get_default_provider(tenant_id) | |||||
| model_credentials = LLMBuilder.get_model_credentials( | model_credentials = LLMBuilder.get_model_credentials( | ||||
| tenant_id=tenant_id, | tenant_id=tenant_id, | ||||
| model_provider=provider, | |||||
| model_name='text-embedding-ada-002' | model_name='text-embedding-ada-002' | ||||
| ) | ) | ||||
| from langchain.llms.fake import FakeListLLM | from langchain.llms.fake import FakeListLLM | ||||
| from core.constant import llm_constant | from core.constant import llm_constant | ||||
| from core.llm.error import ProviderTokenNotInitError | |||||
| from core.llm.provider.base import BaseProvider | |||||
| from core.llm.provider.llm_provider_service import LLMProviderService | from core.llm.provider.llm_provider_service import LLMProviderService | ||||
| from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI | |||||
| from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI | |||||
| from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | from core.llm.streamable_chat_open_ai import StreamableChatOpenAI | ||||
| from core.llm.streamable_open_ai import StreamableOpenAI | from core.llm.streamable_open_ai import StreamableOpenAI | ||||
| from models.provider import ProviderType | |||||
| class LLMBuilder: | class LLMBuilder: | ||||
| if model_name == 'fake': | if model_name == 'fake': | ||||
| return FakeListLLM(responses=[]) | return FakeListLLM(responses=[]) | ||||
| provider = cls.get_default_provider(tenant_id) | |||||
| mode = cls.get_mode_by_model(model_name) | mode = cls.get_mode_by_model(model_name) | ||||
| if mode == 'chat': | if mode == 'chat': | ||||
| # llm_cls = StreamableAzureChatOpenAI | |||||
| llm_cls = StreamableChatOpenAI | |||||
| if provider == 'openai': | |||||
| llm_cls = StreamableChatOpenAI | |||||
| else: | |||||
| llm_cls = StreamableAzureChatOpenAI | |||||
| elif mode == 'completion': | elif mode == 'completion': | ||||
| llm_cls = StreamableOpenAI | |||||
| if provider == 'openai': | |||||
| llm_cls = StreamableOpenAI | |||||
| else: | |||||
| llm_cls = StreamableAzureOpenAI | |||||
| else: | else: | ||||
| raise ValueError(f"model name {model_name} is not supported.") | raise ValueError(f"model name {model_name} is not supported.") | ||||
| model_credentials = cls.get_model_credentials(tenant_id, model_name) | |||||
| model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) | |||||
| return llm_cls( | return llm_cls( | ||||
| model_name=model_name, | model_name=model_name, | ||||
| raise ValueError(f"model name {model_name} is not supported.") | raise ValueError(f"model name {model_name} is not supported.") | ||||
| @classmethod | @classmethod | ||||
| def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict: | |||||
| def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict: | |||||
| """ | """ | ||||
| Returns the API credentials for the given tenant_id and model_name, based on the model's provider. | Returns the API credentials for the given tenant_id and model_name, based on the model's provider. | ||||
| Raises an exception if the model_name is not found or if the provider is not found. | Raises an exception if the model_name is not found or if the provider is not found. | ||||
| """ | """ | ||||
| if not model_name: | if not model_name: | ||||
| raise Exception('model name not found') | raise Exception('model name not found') | ||||
| # | |||||
| # if model_name not in llm_constant.models: | |||||
| # raise Exception('model {} not found'.format(model_name)) | |||||
| if model_name not in llm_constant.models: | |||||
| raise Exception('model {} not found'.format(model_name)) | |||||
| model_provider = llm_constant.models[model_name] | |||||
| # model_provider = llm_constant.models[model_name] | |||||
| provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) | provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) | ||||
| return provider_service.get_credentials(model_name) | return provider_service.get_credentials(model_name) | ||||
| @classmethod | |||||
| def get_default_provider(cls, tenant_id: str) -> str: | |||||
| provider = BaseProvider.get_valid_provider(tenant_id) | |||||
| if not provider: | |||||
| raise ProviderTokenNotInitError() | |||||
| if provider.provider_type == ProviderType.SYSTEM.value: | |||||
| provider_name = 'openai' | |||||
| else: | |||||
| provider_name = provider.provider_name | |||||
| return provider_name |
| """ | """ | ||||
| Returns the API credentials for Azure OpenAI as a dictionary. | Returns the API credentials for Azure OpenAI as a dictionary. | ||||
| """ | """ | ||||
| encrypted_config = self.get_provider_api_key(model_id=model_id) | |||||
| config = json.loads(encrypted_config) | |||||
| config = self.get_provider_api_key(model_id=model_id) | |||||
| config['openai_api_type'] = 'azure' | config['openai_api_type'] = 'azure' | ||||
| config['deployment_name'] = model_id | |||||
| config['deployment_name'] = model_id.replace('.', '') | |||||
| return config | return config | ||||
| def get_provider_name(self): | def get_provider_name(self): | ||||
| """ | """ | ||||
| try: | try: | ||||
| config = self.get_provider_api_key() | config = self.get_provider_api_key() | ||||
| config = json.loads(config) | |||||
| except: | except: | ||||
| config = { | config = { | ||||
| 'openai_api_type': 'azure', | 'openai_api_type': 'azure', | ||||
| 'openai_api_version': '2023-03-15-preview', | 'openai_api_version': '2023-03-15-preview', | ||||
| 'openai_api_base': 'https://foo.microsoft.com/bar', | |||||
| 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/', | |||||
| 'openai_api_key': '' | 'openai_api_key': '' | ||||
| } | } | ||||
| config = { | config = { | ||||
| 'openai_api_type': 'azure', | 'openai_api_type': 'azure', | ||||
| 'openai_api_version': '2023-03-15-preview', | 'openai_api_version': '2023-03-15-preview', | ||||
| 'openai_api_base': 'https://foo.microsoft.com/bar', | |||||
| 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/', | |||||
| 'openai_api_key': '' | 'openai_api_key': '' | ||||
| } | } | ||||
| def __init__(self, tenant_id: str): | def __init__(self, tenant_id: str): | ||||
| self.tenant_id = tenant_id | self.tenant_id = tenant_id | ||||
| def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str: | |||||
| def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]: | |||||
| """ | """ | ||||
| Returns the decrypted API key for the given tenant_id and provider_name. | Returns the decrypted API key for the given tenant_id and provider_name. | ||||
| If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. | If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. | ||||
| Returns the Provider instance for the given tenant_id and provider_name. | Returns the Provider instance for the given tenant_id and provider_name. | ||||
| If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. | If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. | ||||
| """ | """ | ||||
| providers = db.session.query(Provider).filter( | |||||
| Provider.tenant_id == self.tenant_id, | |||||
| Provider.provider_name == self.get_provider_name().value | |||||
| ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() | |||||
| return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) | |||||
| @classmethod | |||||
| def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: | |||||
| """ | |||||
| Returns the Provider instance for the given tenant_id and provider_name. | |||||
| If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. | |||||
| """ | |||||
| query = db.session.query(Provider).filter( | |||||
| Provider.tenant_id == tenant_id | |||||
| ) | |||||
| if provider_name: | |||||
| query = query.filter(Provider.provider_name == provider_name) | |||||
| providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() | |||||
| custom_provider = None | custom_provider = None | ||||
| system_provider = None | system_provider = None | ||||
| for provider in providers: | for provider in providers: | ||||
| if provider.provider_type == ProviderType.CUSTOM.value: | |||||
| if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: | |||||
| custom_provider = provider | custom_provider = provider | ||||
| elif provider.provider_type == ProviderType.SYSTEM.value: | |||||
| elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: | |||||
| system_provider = provider | system_provider = provider | ||||
| if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config: | |||||
| if custom_provider: | |||||
| return custom_provider | return custom_provider | ||||
| elif system_provider and system_provider.is_valid: | |||||
| elif system_provider: | |||||
| return system_provider | return system_provider | ||||
| else: | else: | ||||
| return None | return None | ||||
| try: | try: | ||||
| config = self.get_provider_api_key() | config = self.get_provider_api_key() | ||||
| except: | except: | ||||
| config = 'THIS-IS-A-MOCK-TOKEN' | |||||
| config = '' | |||||
| if obfuscated: | if obfuscated: | ||||
| return self.obfuscated_token(config) | return self.obfuscated_token(config) |
| import requests | |||||
| from langchain.schema import BaseMessage, ChatResult, LLMResult | from langchain.schema import BaseMessage, ChatResult, LLMResult | ||||
| from langchain.chat_models import AzureChatOpenAI | from langchain.chat_models import AzureChatOpenAI | ||||
| from typing import Optional, List | |||||
| from typing import Optional, List, Dict, Any | |||||
| from pydantic import root_validator | |||||
| from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | ||||
| class StreamableAzureChatOpenAI(AzureChatOpenAI): | class StreamableAzureChatOpenAI(AzureChatOpenAI): | ||||
| @root_validator() | |||||
| def validate_environment(cls, values: Dict) -> Dict: | |||||
| """Validate that api key and python package exists in environment.""" | |||||
| try: | |||||
| import openai | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import openai python package. " | |||||
| "Please install it with `pip install openai`." | |||||
| ) | |||||
| try: | |||||
| values["client"] = openai.ChatCompletion | |||||
| except AttributeError: | |||||
| raise ValueError( | |||||
| "`openai` has no `ChatCompletion` attribute, this is likely " | |||||
| "due to an old version of the openai package. Try upgrading it " | |||||
| "with `pip install --upgrade openai`." | |||||
| ) | |||||
| if values["n"] < 1: | |||||
| raise ValueError("n must be at least 1.") | |||||
| if values["n"] > 1 and values["streaming"]: | |||||
| raise ValueError("n must be 1 when streaming.") | |||||
| return values | |||||
| @property | |||||
| def _default_params(self) -> Dict[str, Any]: | |||||
| """Get the default parameters for calling OpenAI API.""" | |||||
| return { | |||||
| **super()._default_params, | |||||
| "engine": self.deployment_name, | |||||
| "api_type": self.openai_api_type, | |||||
| "api_base": self.openai_api_base, | |||||
| "api_version": self.openai_api_version, | |||||
| "api_key": self.openai_api_key, | |||||
| "organization": self.openai_organization if self.openai_organization else None, | |||||
| } | |||||
| def get_messages_tokens(self, messages: List[BaseMessage]) -> int: | def get_messages_tokens(self, messages: List[BaseMessage]) -> int: | ||||
| """Get the number of tokens in a list of messages. | """Get the number of tokens in a list of messages. | ||||
| import os | |||||
| from langchain.llms import AzureOpenAI | |||||
| from langchain.schema import LLMResult | |||||
| from typing import Optional, List, Dict, Mapping, Any | |||||
| from pydantic import root_validator | |||||
| from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | |||||
| class StreamableAzureOpenAI(AzureOpenAI): | |||||
| openai_api_type: str = "azure" | |||||
| openai_api_version: str = "" | |||||
| @root_validator() | |||||
| def validate_environment(cls, values: Dict) -> Dict: | |||||
| """Validate that api key and python package exists in environment.""" | |||||
| try: | |||||
| import openai | |||||
| values["client"] = openai.Completion | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import openai python package. " | |||||
| "Please install it with `pip install openai`." | |||||
| ) | |||||
| if values["streaming"] and values["n"] > 1: | |||||
| raise ValueError("Cannot stream results when n > 1.") | |||||
| if values["streaming"] and values["best_of"] > 1: | |||||
| raise ValueError("Cannot stream results when best_of > 1.") | |||||
| return values | |||||
| @property | |||||
| def _invocation_params(self) -> Dict[str, Any]: | |||||
| return {**super()._invocation_params, **{ | |||||
| "api_type": self.openai_api_type, | |||||
| "api_base": self.openai_api_base, | |||||
| "api_version": self.openai_api_version, | |||||
| "api_key": self.openai_api_key, | |||||
| "organization": self.openai_organization if self.openai_organization else None, | |||||
| }} | |||||
| @property | |||||
| def _identifying_params(self) -> Mapping[str, Any]: | |||||
| return {**super()._identifying_params, **{ | |||||
| "api_type": self.openai_api_type, | |||||
| "api_base": self.openai_api_base, | |||||
| "api_version": self.openai_api_version, | |||||
| "api_key": self.openai_api_key, | |||||
| "organization": self.openai_organization if self.openai_organization else None, | |||||
| }} | |||||
| @handle_llm_exceptions | |||||
| def generate( | |||||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||||
| ) -> LLMResult: | |||||
| return super().generate(prompts, stop) | |||||
| @handle_llm_exceptions_async | |||||
| async def agenerate( | |||||
| self, prompts: List[str], stop: Optional[List[str]] = None | |||||
| ) -> LLMResult: | |||||
| return await super().agenerate(prompts, stop) |
| import os | |||||
| from langchain.schema import BaseMessage, ChatResult, LLMResult | from langchain.schema import BaseMessage, ChatResult, LLMResult | ||||
| from langchain.chat_models import ChatOpenAI | from langchain.chat_models import ChatOpenAI | ||||
| from typing import Optional, List | |||||
| from typing import Optional, List, Dict, Any | |||||
| from pydantic import root_validator | |||||
| from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | ||||
| class StreamableChatOpenAI(ChatOpenAI): | class StreamableChatOpenAI(ChatOpenAI): | ||||
| @root_validator() | |||||
| def validate_environment(cls, values: Dict) -> Dict: | |||||
| """Validate that api key and python package exists in environment.""" | |||||
| try: | |||||
| import openai | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import openai python package. " | |||||
| "Please install it with `pip install openai`." | |||||
| ) | |||||
| try: | |||||
| values["client"] = openai.ChatCompletion | |||||
| except AttributeError: | |||||
| raise ValueError( | |||||
| "`openai` has no `ChatCompletion` attribute, this is likely " | |||||
| "due to an old version of the openai package. Try upgrading it " | |||||
| "with `pip install --upgrade openai`." | |||||
| ) | |||||
| if values["n"] < 1: | |||||
| raise ValueError("n must be at least 1.") | |||||
| if values["n"] > 1 and values["streaming"]: | |||||
| raise ValueError("n must be 1 when streaming.") | |||||
| return values | |||||
| @property | |||||
| def _default_params(self) -> Dict[str, Any]: | |||||
| """Get the default parameters for calling OpenAI API.""" | |||||
| return { | |||||
| **super()._default_params, | |||||
| "api_type": 'openai', | |||||
| "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||||
| "api_version": None, | |||||
| "api_key": self.openai_api_key, | |||||
| "organization": self.openai_organization if self.openai_organization else None, | |||||
| } | |||||
| def get_messages_tokens(self, messages: List[BaseMessage]) -> int: | def get_messages_tokens(self, messages: List[BaseMessage]) -> int: | ||||
| """Get the number of tokens in a list of messages. | """Get the number of tokens in a list of messages. | ||||
| import os | |||||
| from langchain.schema import LLMResult | from langchain.schema import LLMResult | ||||
| from typing import Optional, List | |||||
| from typing import Optional, List, Dict, Any, Mapping | |||||
| from langchain import OpenAI | from langchain import OpenAI | ||||
| from pydantic import root_validator | |||||
| from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async | ||||
| class StreamableOpenAI(OpenAI): | class StreamableOpenAI(OpenAI): | ||||
| @root_validator() | |||||
| def validate_environment(cls, values: Dict) -> Dict: | |||||
| """Validate that api key and python package exists in environment.""" | |||||
| try: | |||||
| import openai | |||||
| values["client"] = openai.Completion | |||||
| except ImportError: | |||||
| raise ValueError( | |||||
| "Could not import openai python package. " | |||||
| "Please install it with `pip install openai`." | |||||
| ) | |||||
| if values["streaming"] and values["n"] > 1: | |||||
| raise ValueError("Cannot stream results when n > 1.") | |||||
| if values["streaming"] and values["best_of"] > 1: | |||||
| raise ValueError("Cannot stream results when best_of > 1.") | |||||
| return values | |||||
| @property | |||||
| def _invocation_params(self) -> Dict[str, Any]: | |||||
| return {**super()._invocation_params, **{ | |||||
| "api_type": 'openai', | |||||
| "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||||
| "api_version": None, | |||||
| "api_key": self.openai_api_key, | |||||
| "organization": self.openai_organization if self.openai_organization else None, | |||||
| }} | |||||
| @property | |||||
| def _identifying_params(self) -> Mapping[str, Any]: | |||||
| return {**super()._identifying_params, **{ | |||||
| "api_type": 'openai', | |||||
| "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"), | |||||
| "api_version": None, | |||||
| "api_key": self.openai_api_key, | |||||
| "organization": self.openai_organization if self.openai_organization else None, | |||||
| }} | |||||
| @handle_llm_exceptions | @handle_llm_exceptions | ||||
| def generate( | def generate( | ||||
| self, prompts: List[str], stop: Optional[List[str]] = None | self, prompts: List[str], stop: Optional[List[str]] = None |
| const [token, setToken] = useState(provider.token as ProviderAzureToken || {}) | const [token, setToken] = useState(provider.token as ProviderAzureToken || {}) | ||||
| const handleFocus = () => { | const handleFocus = () => { | ||||
| if (token === provider.token) { | if (token === provider.token) { | ||||
| token.azure_api_key = '' | |||||
| token.openai_api_key = '' | |||||
| setToken({...token}) | setToken({...token}) | ||||
| onTokenChange({...token}) | onTokenChange({...token}) | ||||
| } | } | ||||
| <div className='px-4 py-3'> | <div className='px-4 py-3'> | ||||
| <ProviderInput | <ProviderInput | ||||
| className='mb-4' | className='mb-4' | ||||
| name={t('common.provider.azure.resourceName')} | |||||
| placeholder={t('common.provider.azure.resourceNamePlaceholder')} | |||||
| value={token.azure_api_base} | |||||
| onChange={(v) => handleChange('azure_api_base', v)} | |||||
| /> | |||||
| <ProviderInput | |||||
| className='mb-4' | |||||
| name={t('common.provider.azure.deploymentId')} | |||||
| placeholder={t('common.provider.azure.deploymentIdPlaceholder')} | |||||
| value={token.azure_api_type} | |||||
| onChange={v => handleChange('azure_api_type', v)} | |||||
| /> | |||||
| <ProviderInput | |||||
| className='mb-4' | |||||
| name={t('common.provider.azure.apiVersion')} | |||||
| placeholder={t('common.provider.azure.apiVersionPlaceholder')} | |||||
| value={token.azure_api_version} | |||||
| onChange={v => handleChange('azure_api_version', v)} | |||||
| name={t('common.provider.azure.apiBase')} | |||||
| placeholder={t('common.provider.azure.apiBasePlaceholder')} | |||||
| value={token.openai_api_base} | |||||
| onChange={(v) => handleChange('openai_api_base', v)} | |||||
| /> | /> | ||||
| <ProviderValidateTokenInput | <ProviderValidateTokenInput | ||||
| className='mb-4' | className='mb-4' | ||||
| name={t('common.provider.azure.apiKey')} | name={t('common.provider.azure.apiKey')} | ||||
| placeholder={t('common.provider.azure.apiKeyPlaceholder')} | placeholder={t('common.provider.azure.apiKeyPlaceholder')} | ||||
| value={token.azure_api_key} | |||||
| onChange={v => handleChange('azure_api_key', v)} | |||||
| value={token.openai_api_key} | |||||
| onChange={v => handleChange('openai_api_key', v)} | |||||
| onFocus={handleFocus} | onFocus={handleFocus} | ||||
| onValidatedStatus={onValidatedStatus} | onValidatedStatus={onValidatedStatus} | ||||
| providerName={provider.provider_name} | providerName={provider.provider_name} | ||||
| ) | ) | ||||
| } | } | ||||
| export default AzureProvider | |||||
| export default AzureProvider |
| const { notify } = useContext(ToastContext) | const { notify } = useContext(ToastContext) | ||||
| const [token, setToken] = useState<ProviderAzureToken | string>( | const [token, setToken] = useState<ProviderAzureToken | string>( | ||||
| provider.provider_name === 'azure_openai' | provider.provider_name === 'azure_openai' | ||||
| ? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' } | |||||
| ? { openai_api_base: '', openai_api_key: '' } | |||||
| : '' | : '' | ||||
| ) | ) | ||||
| const id = `${provider.provider_name}-${provider.provider_type}` | const id = `${provider.provider_name}-${provider.provider_type}` | ||||
| const isOpen = id === activeId | const isOpen = id === activeId | ||||
| const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token | |||||
| const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token | |||||
| const comingSoon = false | const comingSoon = false | ||||
| const isValid = provider.is_valid | const isValid = provider.is_valid | ||||
| ) | ) | ||||
| } | } | ||||
| export default ProviderItem | |||||
| export default ProviderItem |
| editKey: 'Edit', | editKey: 'Edit', | ||||
| invalidApiKey: 'Invalid API key', | invalidApiKey: 'Invalid API key', | ||||
| azure: { | azure: { | ||||
| resourceName: 'Resource Name', | |||||
| resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.', | |||||
| deploymentId: 'Deployment ID', | |||||
| deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.', | |||||
| apiVersion: 'API Version', | |||||
| apiVersionPlaceholder: 'The API version to use for this operation.', | |||||
| apiBase: 'API Base', | |||||
| apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.', | |||||
| apiKey: 'API Key', | apiKey: 'API Key', | ||||
| apiKeyPlaceholder: 'Enter your API key here', | apiKeyPlaceholder: 'Enter your API key here', | ||||
| helpTip: 'Learn Azure OpenAI Service', | helpTip: 'Learn Azure OpenAI Service', |
| editKey: '编辑', | editKey: '编辑', | ||||
| invalidApiKey: '无效的 API 密钥', | invalidApiKey: '无效的 API 密钥', | ||||
| azure: { | azure: { | ||||
| resourceName: 'Resource Name', | |||||
| resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.', | |||||
| deploymentId: 'Deployment ID', | |||||
| deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.', | |||||
| apiVersion: 'API Version', | |||||
| apiVersionPlaceholder: 'The API version to use for this operation.', | |||||
| apiBase: 'API Base', | |||||
| apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址', | |||||
| apiKey: 'API Key', | apiKey: 'API Key', | ||||
| apiKeyPlaceholder: 'Enter your API key here', | |||||
| apiKeyPlaceholder: '输入你的 API 密钥', | |||||
| helpTip: '了解 Azure OpenAI Service', | helpTip: '了解 Azure OpenAI Service', | ||||
| }, | }, | ||||
| openaiHosted: { | openaiHosted: { |
| } | } | ||||
| export type ProviderAzureToken = { | export type ProviderAzureToken = { | ||||
| azure_api_base: string | |||||
| azure_api_key: string | |||||
| azure_api_type: string | |||||
| azure_api_version: string | |||||
| openai_api_base: string | |||||
| openai_api_key: string | |||||
| } | } | ||||
| export type Provider = { | export type Provider = { | ||||
| provider_name: string | provider_name: string |