|
|
|
@@ -0,0 +1,209 @@ |
|
|
|
import json |
|
|
|
import time |
|
|
|
from typing import Optional |
|
|
|
|
|
|
|
import boto3 |
|
|
|
from botocore.config import Config |
|
|
|
from botocore.exceptions import ( |
|
|
|
ClientError, |
|
|
|
EndpointConnectionError, |
|
|
|
NoRegionError, |
|
|
|
ServiceNotInRegionError, |
|
|
|
UnknownServiceError, |
|
|
|
) |
|
|
|
|
|
|
|
from core.model_runtime.entities.model_entities import PriceType |
|
|
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult |
|
|
|
from core.model_runtime.errors.invoke import ( |
|
|
|
InvokeAuthorizationError, |
|
|
|
InvokeBadRequestError, |
|
|
|
InvokeConnectionError, |
|
|
|
InvokeError, |
|
|
|
InvokeRateLimitError, |
|
|
|
InvokeServerUnavailableError, |
|
|
|
) |
|
|
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel |
|
|
|
|
|
|
|
|
|
|
|
class BedrockTextEmbeddingModel(TextEmbeddingModel): |
|
|
|
|
|
|
|
|
|
|
|
def _invoke(self, model: str, credentials: dict, |
|
|
|
texts: list[str], user: Optional[str] = None) \ |
|
|
|
-> TextEmbeddingResult: |
|
|
|
""" |
|
|
|
Invoke text embedding model |
|
|
|
|
|
|
|
:param model: model name |
|
|
|
:param credentials: model credentials |
|
|
|
:param texts: texts to embed |
|
|
|
:param user: unique user id |
|
|
|
:return: embeddings result |
|
|
|
""" |
|
|
|
client_config = Config( |
|
|
|
region_name=credentials["aws_region"] |
|
|
|
) |
|
|
|
|
|
|
|
bedrock_runtime = boto3.client( |
|
|
|
service_name='bedrock-runtime', |
|
|
|
config=client_config, |
|
|
|
aws_access_key_id=credentials["aws_access_key_id"], |
|
|
|
aws_secret_access_key=credentials["aws_secret_access_key"] |
|
|
|
) |
|
|
|
|
|
|
|
embeddings = [] |
|
|
|
token_usage = 0 |
|
|
|
|
|
|
|
model_prefix = model.split('.')[0] |
|
|
|
if model_prefix == "amazon": |
|
|
|
for text in texts: |
|
|
|
body = { |
|
|
|
"inputText": text, |
|
|
|
} |
|
|
|
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) |
|
|
|
embeddings.extend([response_body.get('embedding')]) |
|
|
|
token_usage += response_body.get('inputTextTokenCount') |
|
|
|
result = TextEmbeddingResult( |
|
|
|
model=model, |
|
|
|
embeddings=embeddings, |
|
|
|
usage=self._calc_response_usage( |
|
|
|
model=model, |
|
|
|
credentials=credentials, |
|
|
|
tokens=token_usage |
|
|
|
) |
|
|
|
) |
|
|
|
else: |
|
|
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") |
|
|
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: |
|
|
|
""" |
|
|
|
Get number of tokens for given prompt messages |
|
|
|
|
|
|
|
:param model: model name |
|
|
|
:param credentials: model credentials |
|
|
|
:param texts: texts to embed |
|
|
|
:return: |
|
|
|
""" |
|
|
|
num_tokens = 0 |
|
|
|
for text in texts: |
|
|
|
num_tokens += self._get_num_tokens_by_gpt2(text) |
|
|
|
return num_tokens |
|
|
|
|
|
|
|
def validate_credentials(self, model: str, credentials: dict) -> None: |
|
|
|
""" |
|
|
|
Validate model credentials |
|
|
|
|
|
|
|
:param model: model name |
|
|
|
:param credentials: model credentials |
|
|
|
:return: |
|
|
|
""" |
|
|
|
|
|
|
|
@property |
|
|
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: |
|
|
|
""" |
|
|
|
Map model invoke error to unified error |
|
|
|
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller |
|
|
|
The value is the md = genai.GenerativeModel(model)error type thrown by the model, |
|
|
|
which needs to be converted into a unified error type for the caller. |
|
|
|
|
|
|
|
:return: Invoke emd = genai.GenerativeModel(model)rror mapping |
|
|
|
""" |
|
|
|
return { |
|
|
|
InvokeConnectionError: [], |
|
|
|
InvokeServerUnavailableError: [], |
|
|
|
InvokeRateLimitError: [], |
|
|
|
InvokeAuthorizationError: [], |
|
|
|
InvokeBadRequestError: [] |
|
|
|
} |
|
|
|
|
|
|
|
def _create_payload(self, model_prefix: str, texts: list[str], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): |
|
|
|
""" |
|
|
|
Create payload for bedrock api call depending on model provider |
|
|
|
""" |
|
|
|
payload = dict() |
|
|
|
|
|
|
|
if model_prefix == "amazon": |
|
|
|
payload['inputText'] = texts |
|
|
|
|
|
|
|
|
|
|
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: |
|
|
|
""" |
|
|
|
Calculate response usage |
|
|
|
|
|
|
|
:param model: model name |
|
|
|
:param credentials: model credentials |
|
|
|
:param tokens: input tokens |
|
|
|
:return: usage |
|
|
|
""" |
|
|
|
# get input price info |
|
|
|
input_price_info = self.get_price( |
|
|
|
model=model, |
|
|
|
credentials=credentials, |
|
|
|
price_type=PriceType.INPUT, |
|
|
|
tokens=tokens |
|
|
|
) |
|
|
|
|
|
|
|
# transform usage |
|
|
|
usage = EmbeddingUsage( |
|
|
|
tokens=tokens, |
|
|
|
total_tokens=tokens, |
|
|
|
unit_price=input_price_info.unit_price, |
|
|
|
price_unit=input_price_info.unit, |
|
|
|
total_price=input_price_info.total_amount, |
|
|
|
currency=input_price_info.currency, |
|
|
|
latency=time.perf_counter() - self.started_at |
|
|
|
) |
|
|
|
|
|
|
|
return usage |
|
|
|
|
|
|
|
def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]: |
|
|
|
""" |
|
|
|
Map client error to invoke error |
|
|
|
|
|
|
|
:param error_code: error code |
|
|
|
:param error_msg: error message |
|
|
|
:return: invoke error |
|
|
|
""" |
|
|
|
|
|
|
|
if error_code == "AccessDeniedException": |
|
|
|
return InvokeAuthorizationError(error_msg) |
|
|
|
elif error_code in ["ResourceNotFoundException", "ValidationException"]: |
|
|
|
return InvokeBadRequestError(error_msg) |
|
|
|
elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: |
|
|
|
return InvokeRateLimitError(error_msg) |
|
|
|
elif error_code in ["ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException"]: |
|
|
|
return InvokeServerUnavailableError(error_msg) |
|
|
|
elif error_code == "ModelStreamErrorException": |
|
|
|
return InvokeConnectionError(error_msg) |
|
|
|
|
|
|
|
return InvokeError(error_msg) |
|
|
|
|
|
|
|
|
|
|
|
def _invoke_bedrock_embedding(self, model: str, bedrock_runtime, body: dict, ): |
|
|
|
accept = 'application/json' |
|
|
|
content_type = 'application/json' |
|
|
|
try: |
|
|
|
response = bedrock_runtime.invoke_model( |
|
|
|
body=json.dumps(body), |
|
|
|
modelId=model, |
|
|
|
accept=accept, |
|
|
|
contentType=content_type |
|
|
|
) |
|
|
|
response_body = json.loads(response.get('body').read().decode('utf-8')) |
|
|
|
return response_body |
|
|
|
except ClientError as ex: |
|
|
|
error_code = ex.response['Error']['Code'] |
|
|
|
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" |
|
|
|
raise self._map_client_to_invoke_error(error_code, full_error_msg) |
|
|
|
|
|
|
|
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: |
|
|
|
raise InvokeConnectionError(str(ex)) |
|
|
|
|
|
|
|
except UnknownServiceError as ex: |
|
|
|
raise InvokeServerUnavailableError(str(ex)) |
|
|
|
|
|
|
|
except Exception as ex: |
|
|
|
raise InvokeError(str(ex)) |