| @@ -0,0 +1,21 @@ | |||
| import boto3 | |||
| from botocore.config import Config | |||
| def get_bedrock_client(service_name, credentials=None): | |||
| client_config = Config(region_name=credentials["aws_region"]) | |||
| aws_access_key_id = credentials["aws_access_key_id"] | |||
| aws_secret_access_key = credentials["aws_secret_access_key"] | |||
| if aws_access_key_id and aws_secret_access_key: | |||
| # use aksk to call bedrock | |||
| client = boto3.client( | |||
| service_name=service_name, | |||
| config=client_config, | |||
| aws_access_key_id=aws_access_key_id, | |||
| aws_secret_access_key=aws_secret_access_key, | |||
| ) | |||
| else: | |||
| # use iam without aksk to call | |||
| client = boto3.client(service_name=service_name, config=client_config) | |||
| return client | |||
| @@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import ( | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client | |||
| logger = logging.getLogger(__name__) | |||
| ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. | |||
| @@ -173,13 +174,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): | |||
| :param stream: is stream response | |||
| :return: full response or stream response chunk generator result | |||
| """ | |||
| bedrock_client = boto3.client( | |||
| service_name="bedrock-runtime", | |||
| aws_access_key_id=credentials.get("aws_access_key_id"), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key"), | |||
| region_name=credentials["aws_region"], | |||
| ) | |||
| bedrock_client = get_bedrock_client("bedrock-runtime", credentials) | |||
| system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) | |||
| inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) | |||
| @@ -1,8 +1,5 @@ | |||
| from typing import Optional | |||
| import boto3 | |||
| from botocore.config import Config | |||
| from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | |||
| from core.model_runtime.errors.invoke import ( | |||
| InvokeAuthorizationError, | |||
| @@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import ( | |||
| ) | |||
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |||
| from core.model_runtime.model_providers.__base.rerank_model import RerankModel | |||
| from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client | |||
| class BedrockRerankModel(RerankModel): | |||
| @@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel): | |||
| return RerankResult(model=model, docs=docs) | |||
| # initialize client | |||
| client_config = Config(region_name=credentials["aws_region"]) | |||
| bedrock_runtime = boto3.client( | |||
| service_name="bedrock-agent-runtime", | |||
| config=client_config, | |||
| aws_access_key_id=credentials.get("aws_access_key_id", ""), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key"), | |||
| ) | |||
| bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials) | |||
| queries = [{"type": "TEXT", "textQuery": {"text": query}}] | |||
| text_sources = [] | |||
| for text in docs: | |||
| @@ -3,8 +3,6 @@ import logging | |||
| import time | |||
| from typing import Optional | |||
| import boto3 | |||
| from botocore.config import Config | |||
| from botocore.exceptions import ( | |||
| ClientError, | |||
| EndpointConnectionError, | |||
| @@ -25,6 +23,7 @@ from core.model_runtime.errors.invoke import ( | |||
| InvokeServerUnavailableError, | |||
| ) | |||
| from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | |||
| from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client | |||
| logger = logging.getLogger(__name__) | |||
| @@ -48,14 +47,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): | |||
| :param input_type: input type | |||
| :return: embeddings result | |||
| """ | |||
| client_config = Config(region_name=credentials["aws_region"]) | |||
| bedrock_runtime = boto3.client( | |||
| service_name="bedrock-runtime", | |||
| config=client_config, | |||
| aws_access_key_id=credentials.get("aws_access_key_id"), | |||
| aws_secret_access_key=credentials.get("aws_secret_access_key"), | |||
| ) | |||
| bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials) | |||
| embeddings = [] | |||
| token_usage = 0 | |||