Переглянути джерело

feat: Add support for embed file with AWS Bedrock Titan Model (#3377)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
tags/0.6.3
longzhihun 1 рік тому
джерело
коміт
f7a417fdb4
Аккаунт користувача з таким Email не знайдено

+ 1
- 0
api/core/model_runtime/model_providers/bedrock/bedrock.yaml Переглянути файл

en_US: https://console.aws.amazon.com/ en_US: https://console.aws.amazon.com/
supported_model_types: supported_model_types:
- llm - llm
- text-embedding
configurate_methods: configurate_methods:
- predefined-model - predefined-model
provider_credential_schema: provider_credential_schema:

+ 0
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/__init__.py Переглянути файл


+ 1
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml Переглянути файл

- amazon.titan-embed-text-v1

+ 8
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/amazon.titan-embed-text-v1.yaml Переглянути файл

model: amazon.titan-embed-text-v1
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: '0.0001'
unit: '0.001'
currency: USD

+ 209
- 0
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py Переглянути файл

import json
import time
from typing import Optional

import boto3
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
NoRegionError,
ServiceNotInRegionError,
UnknownServiceError,
)

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel


class BedrockTextEmbeddingModel(TextEmbeddingModel):


def _invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
-> TextEmbeddingResult:
"""
Invoke text embedding model

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
client_config = Config(
region_name=credentials["aws_region"]
)

bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
config=client_config,
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"]
)

embeddings = []
token_usage = 0

model_prefix = model.split('.')[0]
if model_prefix == "amazon":
for text in texts:
body = {
"inputText": text,
}
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
embeddings.extend([response_body.get('embedding')])
token_usage += response_body.get('inputTextTokenCount')
result = TextEmbeddingResult(
model=model,
embeddings=embeddings,
usage=self._calc_response_usage(
model=model,
credentials=credentials,
tokens=token_usage
)
)
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")

return result


def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages

:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:return:
"""
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
which needs to be converted into a unified error type for the caller.

:return: Invoke emd = genai.GenerativeModel(model)rror mapping
"""
return {
InvokeConnectionError: [],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: []
}
def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
"""
Create payload for bedrock api call depending on model provider
"""
payload = dict()

if model_prefix == "amazon":
payload['inputText'] = texts

def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
Calculate response usage

:param model: model name
:param credentials: model credentials
:param tokens: input tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)

return usage
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
"""
Map client error to invoke error

:param error_code: error code
:param error_msg: error message
:return: invoke error
"""

if error_code == "AccessDeniedException":
return InvokeAuthorizationError(error_msg)
elif error_code in ["ResourceNotFoundException", "ValidationException"]:
return InvokeBadRequestError(error_msg)
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]:
return InvokeRateLimitError(error_msg)
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]:
return InvokeServerUnavailableError(error_msg)
elif error_code == "ModelStreamErrorException":
return InvokeConnectionError(error_msg)

return InvokeError(error_msg)

def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ):
accept = 'application/json'
content_type = 'application/json'
try:
response = bedrock_runtime.invoke_model(
body=json.dumps(body),
modelId=model,
accept=accept,
contentType=content_type
)
response_body = json.loads(response.get('body').read().decode('utf-8'))
return response_body
except ClientError as ex:
error_code = ex.response['Error']['Code']
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
raise self._map_client_to_invoke_error(error_code, full_error_msg)
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
raise InvokeConnectionError(str(ex))

except UnknownServiceError as ex:
raise InvokeServerUnavailableError(str(ex))

except Exception as ex:
raise InvokeError(str(ex))

Завантаження…
Відмінити
Зберегти