
Feat: support azure openai for temporary (#101)

tags/0.2.2
John Wang, 2 years ago
commit f68b05d5ec

api/config.py  +5 -0

    'PDF_PREVIEW': 'True',
    'LOG_LEVEL': 'INFO',
    'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
+   'DEFAULT_LLM_PROVIDER': 'openai'
}




# You could disable it for compatibility with certain OpenAPI providers
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

+ # For temporary use only
+ # Set the default LLM provider; default is 'openai', `azure_openai` is also supported
+ self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

class CloudEditionConfig(Config):

    def __init__(self):

api/controllers/console/workspace/providers.py  +29 -15



args = parser.parse_args()

- if not args['token']:
-     raise ValueError('Token is empty')
-
- try:
-     ProviderService.validate_provider_configs(
-         tenant=current_user.current_tenant,
-         provider_name=ProviderName(provider),
-         configs=args['token']
-     )
-     token_is_valid = True
- except ValidateFailedError:
-     token_is_valid = False
+ if args['token']:
+     try:
+         ProviderService.validate_provider_configs(
+             tenant=current_user.current_tenant,
+             provider_name=ProviderName(provider),
+             configs=args['token']
+         )
+         token_is_valid = True
+     except ValidateFailedError:
+         token_is_valid = False
+
+     base64_encrypted_token = ProviderService.get_encrypted_token(
+         tenant=current_user.current_tenant,
+         provider_name=ProviderName(provider),
+         configs=args['token']
+     )
+ else:
+     base64_encrypted_token = None
+     token_is_valid = False

tenant = current_user.current_tenant

- base64_encrypted_token = ProviderService.get_encrypted_token(
-     tenant=current_user.current_tenant,
-     provider_name=ProviderName(provider),
-     configs=args['token']
- )
-
- provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider,
-                                           provider_type=ProviderType.CUSTOM.value).first()
+ provider_model = db.session.query(Provider).filter(
+     Provider.tenant_id == tenant.id,
+     Provider.provider_name == provider,
+     Provider.provider_type == ProviderType.CUSTOM.value
+ ).first()

# Only allow updating token for CUSTOM provider type
if provider_model:
                          is_valid=token_is_valid)
    db.session.add(provider_model)

+     if provider_model.is_valid:
+         other_providers = db.session.query(Provider).filter(
+             Provider.tenant_id == tenant.id,
+             Provider.provider_name != provider,
+             Provider.provider_type == ProviderType.CUSTOM.value
+         ).all()
+
+         for other_provider in other_providers:
+             other_provider.is_valid = False
+
    db.session.commit()

if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,

api/core/embedding/openai_embedding.py  +48 -24



@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
-         text: str,
-         engine: Optional[str] = None,
-         openai_api_key: Optional[str] = None,
+         text: str,
+         engine: Optional[str] = None,
+         api_key: Optional[str] = None,
+         **kwargs
) -> List[float]:
    """Get embedding.
    """
    text = text.replace("\n", " ")
-     return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"]
+     return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]




@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
- async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]:
+ async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

-     return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][
+     return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
        "embedding"
    ]




@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
-         list_of_text: List[str],
-         engine: Optional[str] = None,
-         openai_api_key: Optional[str] = None
+         list_of_text: List[str],
+         engine: Optional[str] = None,
+         api_key: Optional[str] = None,
+         **kwargs
) -> List[List[float]]:
    """Get embeddings.

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

-     data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data
+     data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]




@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
-         list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None
+         list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

-     data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data
+     data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
-         super().__init__(**kwargs)
+         new_kwargs = {}
+
+         if 'embed_batch_size' in kwargs:
+             new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
+
+         if 'tokenizer' in kwargs:
+             new_kwargs['tokenizer'] = kwargs['tokenizer']
+
+         super().__init__(**new_kwargs)
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
+         self.openai_api_type = kwargs.get('openai_api_type')
+         self.openai_api_version = kwargs.get('openai_api_version')
+         self.openai_api_base = kwargs.get('openai_api_base')


    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        if key not in _QUERY_MODE_MODEL_DICT:
            raise ValueError(f"Invalid mode, model combination: {key}")
        engine = _QUERY_MODE_MODEL_DICT[key]
-         return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key)
+         return get_embedding(query, engine=engine, api_key=self.openai_api_key,
+                              api_type=self.openai_api_type, api_version=self.openai_api_version,
+                              api_base=self.openai_api_base)


    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if key not in _TEXT_MODE_MODEL_DICT:
            raise ValueError(f"Invalid mode, model combination: {key}")
        engine = _TEXT_MODE_MODEL_DICT[key]
-         return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+         return get_embedding(text, engine=engine, api_key=self.openai_api_key,
+                              api_type=self.openai_api_type, api_version=self.openai_api_version,
+                              api_base=self.openai_api_base)


    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if key not in _TEXT_MODE_MODEL_DICT:
            raise ValueError(f"Invalid mode, model combination: {key}")
        engine = _TEXT_MODE_MODEL_DICT[key]
-         return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key)
+         return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
+                                     api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                     api_base=self.openai_api_base)


    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.
        if key not in _TEXT_MODE_MODEL_DICT:
            raise ValueError(f"Invalid mode, model combination: {key}")
        engine = _TEXT_MODE_MODEL_DICT[key]
-         embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+         embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                     api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                     api_base=self.openai_api_base)
        return embeddings


    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        if key not in _TEXT_MODE_MODEL_DICT:
            raise ValueError(f"Invalid mode, model combination: {key}")
        engine = _TEXT_MODE_MODEL_DICT[key]
-         embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key)
+         embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
+                                            api_type=self.openai_api_type, api_version=self.openai_api_version,
+                                            api_base=self.openai_api_base)
        return embeddings

api/core/index/index_builder.py  +3 -0

        max_chunk_overlap=20
    )

+     provider = LLMBuilder.get_default_provider(tenant_id)
+
    model_credentials = LLMBuilder.get_model_credentials(
        tenant_id=tenant_id,
+         model_provider=provider,
        model_name='text-embedding-ada-002'
    )



api/core/llm/llm_builder.py  +34 -9

from langchain.llms.fake import FakeListLLM

from core.constant import llm_constant
+ from core.llm.error import ProviderTokenNotInitError
+ from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
+ from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
+ from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
+ from models.provider import ProviderType




class LLMBuilder:
        if model_name == 'fake':
            return FakeListLLM(responses=[])

+         provider = cls.get_default_provider(tenant_id)
+
        mode = cls.get_mode_by_model(model_name)
        if mode == 'chat':
-             # llm_cls = StreamableAzureChatOpenAI
-             llm_cls = StreamableChatOpenAI
+             if provider == 'openai':
+                 llm_cls = StreamableChatOpenAI
+             else:
+                 llm_cls = StreamableAzureChatOpenAI
        elif mode == 'completion':
-             llm_cls = StreamableOpenAI
+             if provider == 'openai':
+                 llm_cls = StreamableOpenAI
+             else:
+                 llm_cls = StreamableAzureOpenAI
        else:
            raise ValueError(f"model name {model_name} is not supported.")

-         model_credentials = cls.get_model_credentials(tenant_id, model_name)
+         model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

        return llm_cls(
            model_name=model_name,
        raise ValueError(f"model name {model_name} is not supported.")


    @classmethod
-     def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict:
+     def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
        """
        Returns the API credentials for the given tenant_id and model_name, based on the model's provider.
        Raises an exception if the model_name is not found or if the provider is not found.
        """
        if not model_name:
            raise Exception('model name not found')
+         #
+         # if model_name not in llm_constant.models:
+         #     raise Exception('model {} not found'.format(model_name))

-         if model_name not in llm_constant.models:
-             raise Exception('model {} not found'.format(model_name))
-
-         model_provider = llm_constant.models[model_name]
+         # model_provider = llm_constant.models[model_name]

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)

    @classmethod
    def get_default_provider(cls, tenant_id: str) -> str:
        provider = BaseProvider.get_valid_provider(tenant_id)
        if not provider:
            raise ProviderTokenNotInitError()

        if provider.provider_type == ProviderType.SYSTEM.value:
            provider_name = 'openai'
        else:
            provider_name = provider.provider_name

        return provider_name

api/core/llm/provider/azure_provider.py  +4 -6

""" """
Returns the API credentials for Azure OpenAI as a dictionary. Returns the API credentials for Azure OpenAI as a dictionary.
""" """
encrypted_config = self.get_provider_api_key(model_id=model_id)
config = json.loads(encrypted_config)
config = self.get_provider_api_key(model_id=model_id)
config['openai_api_type'] = 'azure' config['openai_api_type'] = 'azure'
config['deployment_name'] = model_id
config['deployment_name'] = model_id.replace('.', '')
return config return config


    def get_provider_name(self):
        """
        try:
            config = self.get_provider_api_key()
-             config = json.loads(config)
        except:
            config = {
                'openai_api_type': 'azure',
                'openai_api_version': '2023-03-15-preview',
-                 'openai_api_base': 'https://foo.microsoft.com/bar',
+                 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
                'openai_api_key': ''
            }


            config = {
                'openai_api_type': 'azure',
                'openai_api_version': '2023-03-15-preview',
-                 'openai_api_base': 'https://foo.microsoft.com/bar',
+                 'openai_api_base': 'https://<your-domain-prefix>.openai.azure.com/',
                'openai_api_key': ''
            }



api/core/llm/provider/base.py  +22 -10

    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

-     def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str:
+     def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
        """
        Returns the decrypted API key for the given tenant_id and provider_name.
        If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
        Returns the Provider instance for the given tenant_id and provider_name.
        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
        """
-         providers = db.session.query(Provider).filter(
-             Provider.tenant_id == self.tenant_id,
-             Provider.provider_name == self.get_provider_name().value
-         ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
+         return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
+
+     @classmethod
+     def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
+         """
+         Returns the Provider instance for the given tenant_id and provider_name.
+         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
+         """
+         query = db.session.query(Provider).filter(
+             Provider.tenant_id == tenant_id
+         )
+
+         if provider_name:
+             query = query.filter(Provider.provider_name == provider_name)
+
+         providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()

        custom_provider = None
        system_provider = None

        for provider in providers:
-             if provider.provider_type == ProviderType.CUSTOM.value:
+             if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
                custom_provider = provider
-             elif provider.provider_type == ProviderType.SYSTEM.value:
+             elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
                system_provider = provider

-         if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config:
+         if custom_provider:
            return custom_provider
-         elif system_provider and system_provider.is_valid:
+         elif system_provider:
            return system_provider
        else:
            return None

        try:
            config = self.get_provider_api_key()
        except:
-             config = 'THIS-IS-A-MOCK-TOKEN'
+             config = ''

        if obfuscated:
            return self.obfuscated_token(config)

api/core/llm/streamable_azure_chat_open_ai.py  +40 -2

+ import requests
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import AzureChatOpenAI
- from typing import Optional, List
+ from typing import Optional, List, Dict, Any
+
+ from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableAzureChatOpenAI(AzureChatOpenAI):
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            **super()._default_params,
            "engine": self.deployment_name,
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.



api/core/llm/streamable_azure_open_ai.py  +64 -0

import os

from langchain.llms import AzureOpenAI
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Mapping, Any

from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableAzureOpenAI(AzureOpenAI):
    openai_api_type: str = "azure"
    openai_api_version: str = ""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": self.openai_api_type,
            "api_base": self.openai_api_base,
            "api_version": self.openai_api_version,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    def generate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(prompts, stop)

    @handle_llm_exceptions_async
    async def agenerate(
            self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(prompts, stop)

api/core/llm/streamable_chat_open_ai.py  +41 -1

+ import os
+
from langchain.schema import BaseMessage, ChatResult, LLMResult
from langchain.chat_models import ChatOpenAI
- from typing import Optional, List
+ from typing import Optional, List, Dict, Any
+
+ from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):


    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            **super()._default_params,
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.



api/core/llm/streamable_open_ai.py  +43 -1

+ import os
+
from langchain.schema import LLMResult
- from typing import Optional, List
+ from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
+ from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableOpenAI(OpenAI):


    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**super()._invocation_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {**super()._identifying_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}


    @handle_llm_exceptions
    def generate(
            self, prompts: List[str], stop: Optional[List[str]] = None

web/app/components/header/account-setting/provider-page/azure-provider/index.tsx  +8 -22

const [token, setToken] = useState(provider.token as ProviderAzureToken || {})
const handleFocus = () => {
  if (token === provider.token) {
-     token.azure_api_key = ''
+     token.openai_api_key = ''
    setToken({...token})
    onTokenChange({...token})
  }

<div className='px-4 py-3'>
  <ProviderInput
    className='mb-4'
-     name={t('common.provider.azure.resourceName')}
-     placeholder={t('common.provider.azure.resourceNamePlaceholder')}
-     value={token.azure_api_base}
-     onChange={(v) => handleChange('azure_api_base', v)}
-   />
-   <ProviderInput
-     className='mb-4'
-     name={t('common.provider.azure.deploymentId')}
-     placeholder={t('common.provider.azure.deploymentIdPlaceholder')}
-     value={token.azure_api_type}
-     onChange={v => handleChange('azure_api_type', v)}
-   />
-   <ProviderInput
-     className='mb-4'
-     name={t('common.provider.azure.apiVersion')}
-     placeholder={t('common.provider.azure.apiVersionPlaceholder')}
-     value={token.azure_api_version}
-     onChange={v => handleChange('azure_api_version', v)}
+     name={t('common.provider.azure.apiBase')}
+     placeholder={t('common.provider.azure.apiBasePlaceholder')}
+     value={token.openai_api_base}
+     onChange={(v) => handleChange('openai_api_base', v)}
  />
  <ProviderValidateTokenInput
    className='mb-4'
    name={t('common.provider.azure.apiKey')}
    placeholder={t('common.provider.azure.apiKeyPlaceholder')}
-     value={token.azure_api_key}
-     onChange={v => handleChange('azure_api_key', v)}
+     value={token.openai_api_key}
+     onChange={v => handleChange('openai_api_key', v)}
    onFocus={handleFocus}
    onValidatedStatus={onValidatedStatus}
    providerName={provider.provider_name}
  )
}

export default AzureProvider

web/app/components/header/account-setting/provider-page/provider-item/index.tsx  +3 -3

const { notify } = useContext(ToastContext)
const [token, setToken] = useState<ProviderAzureToken | string>(
  provider.provider_name === 'azure_openai'
-     ? { azure_api_base: '', azure_api_type: '', azure_api_version: '', azure_api_key: '' }
+     ? { openai_api_base: '', openai_api_key: '' }
    : ''
)
const id = `${provider.provider_name}-${provider.provider_type}`
const isOpen = id === activeId
- const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.azure_api_key : provider.token
+ const providerKey = provider.provider_name === 'azure_openai' ? (provider.token as ProviderAzureToken)?.openai_api_key : provider.token
const comingSoon = false
const isValid = provider.is_valid

  )
}

export default ProviderItem

web/i18n/lang/common.en.ts  +2 -6

editKey: 'Edit',
invalidApiKey: 'Invalid API key',
azure: {
-   resourceName: 'Resource Name',
-   resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-   deploymentId: 'Deployment ID',
-   deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-   apiVersion: 'API Version',
-   apiVersionPlaceholder: 'The API version to use for this operation.',
+   apiBase: 'API Base',
+   apiBasePlaceholder: 'The API Base URL of your Azure OpenAI Resource.',
  apiKey: 'API Key',
  apiKeyPlaceholder: 'Enter your API key here',
  helpTip: 'Learn Azure OpenAI Service',

web/i18n/lang/common.zh.ts  +3 -7

editKey: '编辑',
invalidApiKey: '无效的 API 密钥',
azure: {
-   resourceName: 'Resource Name',
-   resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.',
-   deploymentId: 'Deployment ID',
-   deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.',
-   apiVersion: 'API Version',
-   apiVersionPlaceholder: 'The API version to use for this operation.',
+   apiBase: 'API Base',
+   apiBasePlaceholder: '输入您的 Azure OpenAI API Base 地址',
  apiKey: 'API Key',
-   apiKeyPlaceholder: 'Enter your API key here',
+   apiKeyPlaceholder: '输入你的 API 密钥',
  helpTip: '了解 Azure OpenAI Service',
},
openaiHosted: {

web/models/common.ts  +2 -4

}

export type ProviderAzureToken = {
-   azure_api_base: string
-   azure_api_key: string
-   azure_api_type: string
-   azure_api_version: string
+   openai_api_base: string
+   openai_api_key: string
}
export type Provider = {
  provider_name: string
