@@ -47,6 +47,7 @@ DEFAULTS = {
     'PDF_PREVIEW': 'True',
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
+    'DEFAULT_LLM_PROVIDER': 'openai'
 }

@@ -181,6 +182,10 @@ class Config:
         # You could disable it for compatibility with certain OpenAPI providers
         self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

+        # For temp use only
+        # set default LLM provider, default is 'openai', support `azure_openai`
+        self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')
+

 class CloudEditionConfig(Config):

     def __init__(self):
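A minimal sketch of how the new setting resolves, assuming `get_env` prefers the process environment and falls back to the `DEFAULTS` table (that helper's implementation is not part of this patch):

```python
import os

DEFAULTS = {'DEFAULT_LLM_PROVIDER': 'openai'}  # excerpt of the table above

def get_env(key):
    # Assumed behavior: environment variable first, then DEFAULTS.
    return os.environ.get(key, DEFAULTS.get(key))

os.environ.pop('DEFAULT_LLM_PROVIDER', None)
assert get_env('DEFAULT_LLM_PROVIDER') == 'openai'        # built-in default

os.environ['DEFAULT_LLM_PROVIDER'] = 'azure_openai'       # the other supported value
assert get_env('DEFAULT_LLM_PROVIDER') == 'azure_openai'
```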
@@ -82,29 +82,33 @@ class ProviderTokenApi(Resource):
         args = parser.parse_args()

-        if not args['token']:
-            raise ValueError('Token is empty')
-
-        try:
-            ProviderService.validate_provider_configs(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
-            )
-            token_is_valid = True
-        except ValidateFailedError:
-            token_is_valid = False
+        if args['token']:
+            try:
+                ProviderService.validate_provider_configs(
+                    tenant=current_user.current_tenant,
+                    provider_name=ProviderName(provider),
+                    configs=args['token']
+                )
+                token_is_valid = True
+            except ValidateFailedError:
+                token_is_valid = False
+
+            base64_encrypted_token = ProviderService.get_encrypted_token(
+                tenant=current_user.current_tenant,
+                provider_name=ProviderName(provider),
+                configs=args['token']
+            )
+        else:
+            base64_encrypted_token = None
+            token_is_valid = False

         tenant = current_user.current_tenant

-        base64_encrypted_token = ProviderService.get_encrypted_token(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider),
-            configs=args['token']
-        )
-
-        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
-                                                  provider_type=ProviderType.CUSTOM.value).first()
+        provider_model = db.session.query(Provider).filter(
+            Provider.tenant_id == tenant.id,
+            Provider.provider_name == provider,
+            Provider.provider_type == ProviderType.CUSTOM.value
+        ).first()

         # Only allow updating token for CUSTOM provider type
         if provider_model:

@@ -117,6 +121,16 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)

+        if provider_model.is_valid:
+            other_providers = db.session.query(Provider).filter(
+                Provider.tenant_id == tenant.id,
+                Provider.provider_name != provider,
+                Provider.provider_type == ProviderType.CUSTOM.value
+            ).all()
+
+            for other_provider in other_providers:
+                other_provider.is_valid = False
+
         db.session.commit()

         if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
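The block above enforces a single-active-custom-provider rule: saving a valid token for one provider marks every other CUSTOM provider of the tenant invalid. A hypothetical post-commit check, reusing the handler's own names (not part of the patch):

```python
# After db.session.commit(), at most one CUSTOM provider per tenant
# should remain valid.
valid_count = db.session.query(Provider).filter(
    Provider.tenant_id == tenant.id,
    Provider.provider_type == ProviderType.CUSTOM.value,
    Provider.is_valid.is_(True)
).count()
assert valid_count <= 1
```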
@@ -11,9 +11,10 @@ from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_except
 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embedding(
-    text: str,
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None,
+        text: str,
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[float]:
     """Get embedding.

@@ -25,11 +26,12 @@ def get_embedding(
     """
     text = text.replace("\n", " ")
-    return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
+    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
+async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
+    float]:
     """Asynchronously get embedding.

     NOTE: Copied from OpenAI's embedding utils:

@@ -42,16 +44,17 @@ async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key
     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
-    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
+    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
         "embedding"
     ]


 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str],
-    engine: Optional[str] = None,
-    openai_api_key: Optional[str] = None
+        list_of_text: List[str],
+        engine: Optional[str] = None,
+        api_key: Optional[str] = None,
+        **kwargs
 ) -> List[List[float]]:
     """Get embeddings.

@@ -67,14 +70,14 @@ def get_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
-    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
+    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]


 @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
+    list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
 ) -> List[List[float]]:
     """Asynchronously get embeddings.

@@ -90,7 +93,7 @@ async def aget_embeddings(
     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
-    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
+    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
     return [d["embedding"] for d in data]
@@ -98,19 +101,30 @@ async def aget_embeddings(
 class OpenAIEmbedding(BaseEmbedding):
     def __init__(
-        self,
-        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
-        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
-        deployment_name: Optional[str] = None,
-        openai_api_key: Optional[str] = None,
-        **kwargs: Any,
+            self,
+            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
+            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
+            deployment_name: Optional[str] = None,
+            openai_api_key: Optional[str] = None,
+            **kwargs: Any,
     ) -> None:
         """Init params."""
-        super().__init__(**kwargs)
+        new_kwargs = {}
+        if 'embed_batch_size' in kwargs:
+            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
+
+        if 'tokenizer' in kwargs:
+            new_kwargs['tokenizer'] = kwargs['tokenizer']
+
+        super().__init__(**new_kwargs)
         self.mode = OpenAIEmbeddingMode(mode)
         self.model = OpenAIEmbeddingModelType(model)
         self.deployment_name = deployment_name
         self.openai_api_key = openai_api_key
+        self.openai_api_type = kwargs.get('openai_api_type')
+        self.openai_api_version = kwargs.get('openai_api_version')
+        self.openai_api_base = kwargs.get('openai_api_base')

     @handle_llm_exceptions
     def _get_query_embedding(self, query: str) -> List[float]:
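The constructor now forwards only the kwargs `BaseEmbedding.__init__` understands (`embed_batch_size`, `tokenizer`) and keeps the Azure connection fields on the instance for the calls below. A hedged construction example (credential values are placeholders):

```python
embedding = OpenAIEmbedding(
    openai_api_key='<azure-api-key>',
    openai_api_type='azure',
    openai_api_version='2023-03-15-preview',
    openai_api_base='https://<your-domain-prefix>.openai.azure.com/',
    embed_batch_size=10,  # passed through to BaseEmbedding.__init__
)
```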
@@ -122,7 +136,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _QUERY_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _QUERY_MODE_MODEL_DICT[key]
-        return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)

     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""

@@ -133,7 +149,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
+                             api_type=self.openai_api_type, api_version=self.openai_api_version,
+                             api_base=self.openai_api_base)

     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Asynchronously get text embedding."""

@@ -144,7 +162,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)

     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings.

@@ -160,7 +180,9 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                    api_base=self.openai_api_base)
         return embeddings

     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:

@@ -172,5 +194,7 @@ class OpenAIEmbedding(BaseEmbedding):
         if key not in _TEXT_MODE_MODEL_DICT:
             raise ValueError(f"Invalid mode, model combination: {key}")
         engine = _TEXT_MODE_MODEL_DICT[key]
-        embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                           api_base=self.openai_api_base)
         return embeddings
@@ -33,8 +33,11 @@ class IndexBuilder:
             max_chunk_overlap=20
         )

+        provider = LLMBuilder.get_default_provider(tenant_id)
+
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=tenant_id,
+            model_provider=provider,
             model_name='text-embedding-ada-002'
         )
@@ -4,9 +4,14 @@ from langchain.callbacks import CallbackManager
 from langchain.llms.fake import FakeListLLM

 from core.constant import llm_constant
+from core.llm.error import ProviderTokenNotInitError
+from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
+from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
+from models.provider import ProviderType


 class LLMBuilder:

@@ -31,16 +36,23 @@ class LLMBuilder:
         if model_name == 'fake':
             return FakeListLLM(responses=[])

+        provider = cls.get_default_provider(tenant_id)
+
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            # llm_cls = StreamableAzureChatOpenAI
-            llm_cls = StreamableChatOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableChatOpenAI
+            else:
+                llm_cls = StreamableAzureChatOpenAI
         elif mode == 'completion':
-            llm_cls = StreamableOpenAI
+            if provider == 'openai':
+                llm_cls = StreamableOpenAI
+            else:
+                llm_cls = StreamableAzureOpenAI
         else:
             raise ValueError(f"model name {model_name} is not supported.")

-        model_credentials = cls.get_model_credentials(tenant_id, model_name)
+        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

         return llm_cls(
             model_name=model_name,
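The branch above is a two-axis dispatch; restated as a table (my summary, not code from the patch); note that any provider other than 'openai' currently falls through to the Azure classes:

```python
LLM_CLASS_BY_MODE_AND_PROVIDER = {
    ('chat', 'openai'): StreamableChatOpenAI,
    ('chat', 'azure_openai'): StreamableAzureChatOpenAI,
    ('completion', 'openai'): StreamableOpenAI,
    ('completion', 'azure_openai'): StreamableAzureOpenAI,
}
```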
@@ -86,18 +98,31 @@ class LLMBuilder:
             raise ValueError(f"model name {model_name} is not supported.")

     @classmethod
-    def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
+    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
         """
         Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
         Raises an exception if the model_name is not found or if the provider is not found.
         """
         if not model_name:
             raise Exception('model name not found')

-        if model_name not in llm_constant.models:
-            raise Exception('model {} not found'.format(model_name))
+        #
+        # if model_name not in llm_constant.models:
+        #     raise Exception('model {} not found'.format(model_name))

-        model_provider = llm_constant.models[model_name]
+        # model_provider = llm_constant.models[model_name]

         provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
         return provider_service.get_credentials(model_name)
+
+    @classmethod
+    def get_default_provider(cls, tenant_id: str) -> str:
+        provider = BaseProvider.get_valid_provider(tenant_id)
+        if not provider:
+            raise ProviderTokenNotInitError()
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            provider_name = 'openai'
+        else:
+            provider_name = provider.provider_name
+
+        return provider_name
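Callers now resolve the provider once and thread it through, exactly as the IndexBuilder hunk above does:

```python
provider = LLMBuilder.get_default_provider(tenant_id)  # 'openai', or the tenant's custom provider

model_credentials = LLMBuilder.get_model_credentials(
    tenant_id=tenant_id,
    model_provider=provider,
    model_name='text-embedding-ada-002'
)
```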
@@ -36,10 +36,9 @@ class AzureProvider(BaseProvider):
         """
         Returns the API credentials for Azure OpenAI as a dictionary.
         """
-        encrypted_config = self.get_provider_api_key(model_id=model_id)
-        config = json.loads(encrypted_config)
+        config = self.get_provider_api_key(model_id=model_id)
         config['openai_api_type'] = 'azure'
-        config['deployment_name'] = model_id
+        config['deployment_name'] = model_id.replace('.', '')
         return config

     def get_provider_name(self):
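The `.replace('.', '')` normalization reflects that Azure OpenAI deployment names cannot contain periods, so dotted model ids map onto the conventional Azure deployment names (assuming the deployments were created under those names):

```python
model_id = 'gpt-3.5-turbo'
deployment_name = model_id.replace('.', '')
assert deployment_name == 'gpt-35-turbo'  # usual Azure deployment naming
```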
@@ -51,12 +50,11 @@ class AzureProvider(BaseProvider):
         """
         try:
             config = self.get_provider_api_key()
-            config = json.loads(config)
         except:
             config = {
                 'openai_api_type': 'azure',
                 'openai_api_version': '2023-03-15-preview',
-                'openai_api_base': 'https://foo.microsoft.com/bar',
+                'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
                 'openai_api_key': ''
             }

@@ -65,7 +63,7 @@ class AzureProvider(BaseProvider):
             config = {
                 'openai_api_type': 'azure',
                 'openai_api_version': '2023-03-15-preview',
-                'openai_api_base': 'https://foo.microsoft.com/bar',
+                'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
                 'openai_api_key': ''
             }
@@ -14,7 +14,7 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id

-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
+    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str, dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.

@@ -43,23 +43,35 @@ class BaseProvider(ABC):
         Returns the Provider instance for the given tenant_id and provider_name.
         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
         """
-        providers = db.session.query(Provider).filter(
-            Provider.tenant_id == self.tenant_id,
-            Provider.provider_name == self.get_provider_name().value
-        ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
+        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
+
+    @classmethod
+    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
+        """
+        Returns the Provider instance for the given tenant_id and provider_name.
+        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
+        """
+        query = db.session.query(Provider).filter(
+            Provider.tenant_id == tenant_id
+        )
+
+        if provider_name:
+            query = query.filter(Provider.provider_name == provider_name)
+
+        providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()

         custom_provider = None
         system_provider = None

         for provider in providers:
-            if provider.provider_type == ProviderType.CUSTOM.value:
+            if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
                 custom_provider = provider
-            elif provider.provider_type == ProviderType.SYSTEM.value:
+            elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
                 system_provider = provider

-        if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
+        if custom_provider:
             return custom_provider
-        elif system_provider and system_provider.is_valid:
+        elif system_provider:
             return system_provider
         else:
             return None
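Selection order after this change: a valid CUSTOM provider with stored credentials wins, then a valid SYSTEM (hosted) provider, then None. A sketch of the new classmethod in use:

```python
# Any valid provider for the tenant (this is what LLMBuilder.get_default_provider calls):
provider = BaseProvider.get_valid_provider(tenant_id)

# Narrowed to one provider name:
azure = BaseProvider.get_valid_provider(tenant_id, provider_name='azure_openai')

if provider is None:
    raise ProviderTokenNotInitError()  # how get_default_provider reacts
```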
@@ -80,7 +92,7 @@ class BaseProvider(ABC):
         try:
             config = self.get_provider_api_key()
         except:
-            config = 'THIS-IS-A-MOCK-TOKEN'
+            config = ''

         if obfuscated:
             return self.obfuscated_token(config)
@@ -1,12 +1,50 @@
 import requests
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import AzureChatOpenAI
-from typing import Optional, List
+from typing import Optional, List, Dict, Any
+
+from pydantic import root_validator

 from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


 class StreamableAzureChatOpenAI(AzureChatOpenAI):
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            **super()._default_params,
+            "engine": self.deployment_name,
+            "api_type": self.openai_api_type,
+            "api_base": self.openai_api_base,
+            "api_version": self.openai_api_version,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }
+
     def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in a list of messages.
@@ -0,0 +1,64 @@
+import os
+
+from langchain.llms import AzureOpenAI
+from langchain.schema import LLMResult
+from typing import Optional, List, Dict, Mapping, Any
+
+from pydantic import root_validator
+
+from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+
+
+class StreamableAzureOpenAI(AzureOpenAI):
+    openai_api_type: str = "azure"
+    openai_api_version: str = ""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+            values["client"] = openai.Completion
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
+        return values
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**super()._invocation_params, **{
+            "api_type": self.openai_api_type,
+            "api_base": self.openai_api_base,
+            "api_version": self.openai_api_version,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {**super()._identifying_params, **{
+            "api_type": self.openai_api_type,
+            "api_base": self.openai_api_base,
+            "api_version": self.openai_api_version,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @handle_llm_exceptions
+    def generate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return super().generate(prompts, stop)
+
+    @handle_llm_exceptions_async
+    async def agenerate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return await super().agenerate(prompts, stop)
@@ -1,12 +1,52 @@
+import os
+
 from langchain.schema import BaseMessage, ChatResult, LLMResult
 from langchain.chat_models import ChatOpenAI
-from typing import Optional, List
+from typing import Optional, List, Dict, Any
+
+from pydantic import root_validator

 from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


 class StreamableChatOpenAI(ChatOpenAI):
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            **super()._default_params,
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }
+
     def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in a list of messages.
@@ -1,12 +1,54 @@
+import os
+
 from langchain.schema import LLMResult
-from typing import Optional, List
+from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
+from pydantic import root_validator

 from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


 class StreamableOpenAI(OpenAI):
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+            values["client"] = openai.Completion
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        if values["streaming"] and values["n"] > 1:
+            raise ValueError("Cannot stream results when n > 1.")
+        if values["streaming"] and values["best_of"] > 1:
+            raise ValueError("Cannot stream results when best_of > 1.")
+        return values
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        return {**super()._invocation_params, **{
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {**super()._identifying_params, **{
+            "api_type": 'openai',
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            "api_version": None,
+            "api_key": self.openai_api_key,
+            "organization": self.openai_organization if self.openai_organization else None,
+        }}
+
     @handle_llm_exceptions
     def generate(
             self, prompts: List[str], stop: Optional[List[str]] = None
@@ -20,7 +20,7 @@ const AzureProvider = ({
   const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
   const handleFocus = () => {
     if (token === provider.token) {
-      token.azure_api_key = ''
+      token.openai_api_key = ''
       setToken({...token})
       onTokenChange({...token})
     }

@@ -35,31 +35,17 @@ const AzureProvider = ({
     <div className='px-4 py-3'>
       <ProviderInput
         className='mb-4'
-        name={t('common.provider.azure.resourceName')}
-        placeholder={t('common.provider.azure.resourceNamePlaceholder')}
-        value={token.azure_api_base}
-        onChange={(v) => handleChange('azure_api_base', v)}
-      />
-      <ProviderInput
-        className='mb-4'
-        name={t('common.provider.azure.deploymentId')}
-        placeholder={t('common.provider.azure.deploymentIdPlaceholder')}
-        value={token.azure_api_type}
-        onChange={v => handleChange('azure_api_type', v)}
-      />
-      <ProviderInput
-        className='mb-4'
-        name={t('common.provider.azure.apiVersion')}
-        placeholder={t('common.provider.azure.apiVersionPlaceholder')}
-        value={token.azure_api_version}
-        onChange={v => handleChange('azure_api_version', v)}
+        name={t('common.provider.azure.apiBase')}
+        placeholder={t('common.provider.azure.apiBasePlaceholder')}
+        value={token.openai_api_base}
+        onChange={(v) => handleChange('openai_api_base', v)}
       />
       <ProviderValidateTokenInput
         className='mb-4'
         name={t('common.provider.azure.apiKey')}
         placeholder={t('common.provider.azure.apiKeyPlaceholder')}
-        value={token.azure_api_key}
-        onChange={v => handleChange('azure_api_key', v)}
+        value={token.openai_api_key}
+        onChange={v => handleChange('openai_api_key', v)}
         onFocus={handleFocus}
         onValidatedStatus={onValidatedStatus}
         providerName={provider.provider_name}

@@ -72,4 +58,4 @@ const AzureProvider = ({
   )
 }

-export default AzureProvider
\ No newline at end of file
+export default AzureProvider
@@ -33,12 +33,12 @@ const ProviderItem = ({
   const { notify } = useContext(ToastContext)
   const [token, setToken] = useState<ProviderAzureToken | string>(
     provider.provider_name === 'azure_openai'
-      ? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' }
+      ? { openai_api_base: '', openai_api_key: '' }
       : ''
   )
   const id = `${provider.provider_name}-${provider.provider_type}`
   const isOpen = id === activeId
-  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token
+  const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token
   const comingSoon = false
   const isValid = provider.is_valid

@@ -135,4 +135,4 @@ const ProviderItem = ({
   )
 }

-export default ProviderItem
\ No newline at end of file
+export default ProviderItem
@@ -148,12 +148,8 @@ const translation = {
     editKey: 'Edit',
     invalidApiKey: 'Invalid API key',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
       apiKey: 'API Key',
       apiKeyPlaceholder: 'Enter your API key here',
       helpTip: 'Learn Azure OpenAI Service',
@@ -149,14 +149,10 @@ const translation = {
     editKey: '编辑',
     invalidApiKey: '无效的 API 密钥',
     azure: {
-      resourceName: 'Resource Name',
-      resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-      deploymentId: 'Deployment ID',
-      deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-      apiVersion: 'API Version',
-      apiVersionPlaceholder: 'The API version to use for this operation.',
+      apiBase: 'API Base',
+      apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
       apiKey: 'API Key',
-      apiKeyPlaceholder: 'Enter your API key here',
+      apiKeyPlaceholder: '输入你的 API 密钥',
       helpTip: '了解 Azure OpenAI Service',
     },
     openaiHosted: {
@@ -55,10 +55,8 @@ export type Member = Pick<UserProfileResponse, 'id' | 'name' | 'email' | 'last_l
 }

 export type ProviderAzureToken = {
-  azure_api_base: string
-  azure_api_key: string
-  azure_api_type: string
-  azure_api_version: string
+  openai_api_base: string
+  openai_api_key: string
 }

 export type Provider = {
   provider_name: string