@@ -25,6 +25,7 @@ from botocore.exceptions import (
    ServiceNotInRegionError,
    UnknownServiceError,
)
from cohere import ChatMessage

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@@ -48,6 +49,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel

logger = logging.getLogger(__name__)

@@ -75,8 +77,86 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
        # invoke anthropic models via anthropic official SDK
        if "anthropic" in model:
            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
        # invoke Cohere models via boto3 client
        if "cohere.command-r" in model:
            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
        # invoke other models via boto3 client
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
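
A minimal sketch of the dispatch above, for illustration only. The model IDs below are assumed from Bedrock's published naming scheme, not taken from this diff:

def pick_branch(model: str) -> str:
    if "anthropic" in model:
        return "_generate_anthropic"
    if "cohere.command-r" in model:
        return "_generate_cohere_chat"
    return "_generate"

assert pick_branch("anthropic.claude-3-sonnet-20240229-v1:0") == "_generate_anthropic"
assert pick_branch("cohere.command-r-plus-v1:0") == "_generate_cohere_chat"
assert pick_branch("cohere.command-text-v14") == "_generate"  # plain Command falls through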

    def _generate_cohere_chat(
        self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
        stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
        tools: Optional[list[PromptMessageTool]] = None,
    ) -> Union[LLMResult, Generator]:
        cohere_llm = CohereLargeLanguageModel()
        client_config = Config(
            region_name=credentials["aws_region"]
        )

        runtime_client = boto3.client(
            service_name='bedrock-runtime',
            config=client_config,
            aws_access_key_id=credentials["aws_access_key_id"],
            aws_secret_access_key=credentials["aws_secret_access_key"]
        )
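
For reference, a sketch of the credentials dict this method reads. The keys are the ones looked up above; the values are placeholders, not real settings:

example_credentials = {
    "aws_region": "us-east-1",           # placeholder region
    "aws_access_key_id": "AKIA...",      # placeholder
    "aws_secret_access_key": "...",      # placeholder
}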

        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop_sequences'] = stop

        if tools:
            tools = cohere_llm._convert_tools(tools)
            model_parameters['tools'] = tools
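
The exact tool schema comes from the Cohere provider's _convert_tools, which is not shown in this diff. As a rough, assumed illustration of Cohere's documented tool format:

example_tool = {
    "name": "get_weather",  # hypothetical tool name
    "description": "Look up the current weather for a city.",
    "parameter_definitions": {
        "city": {"type": "str", "description": "City name", "required": True},
    },
}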

        message, chat_histories, tool_results \
            = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)

        if tool_results:
            model_parameters['tool_results'] = tool_results

        payload = {
            **model_parameters,
            **extra_model_kwargs,
            "message": message,
            "chat_history": chat_histories,
        }
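
An illustrative request body for a Command R model, assuming chat_histories serializes to Cohere's role/message dicts (all values below are made up; any extra keys come from model_parameters):

example_payload = {
    "message": "What's the weather in Paris?",
    "chat_history": [
        {"role": "USER", "message": "Hi"},
        {"role": "CHATBOT", "message": "Hello! How can I help?"},
    ],
    "temperature": 0.3,
    "stop_sequences": ["\n\nHuman:"],
}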

        # Bedrock exposes separate APIs for streaming and non-streaming
        # invocation, so pick the client method up front
        if stream:
            invoke = runtime_client.invoke_model_with_response_stream
        else:
            invoke = runtime_client.invoke_model

        def serialize(obj):
            # chat_history holds cohere ChatMessage objects, which json.dumps
            # cannot encode without this hook
            if isinstance(obj, ChatMessage):
                return obj.__dict__
            raise TypeError(f"Type {type(obj)} not serializable")
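
A stand-alone sketch of why the default= hook is needed, using a stand-in class so the snippet runs without the cohere SDK (FakeChatMessage is hypothetical):

import json

class FakeChatMessage:  # stands in for cohere.ChatMessage
    def __init__(self, role, message):
        self.role = role
        self.message = message

def serialize(obj):
    if isinstance(obj, FakeChatMessage):
        return obj.__dict__
    raise TypeError(f"Type {type(obj)} not serializable")

print(json.dumps({"chat_history": [FakeChatMessage("USER", "Hi")]}, default=serialize))
# prints: {"chat_history": [{"role": "USER", "message": "Hi"}]}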

        try:
            body_jsonstr = json.dumps(payload, default=serialize)
            response = invoke(
                modelId=model,
                contentType="application/json",
                accept="*/*",
                body=body_jsonstr
            )
        except ClientError as ex:
            error_code = ex.response['Error']['Code']
            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
            raise self._map_client_to_invoke_error(error_code, full_error_msg)

        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
            raise InvokeConnectionError(str(ex))

        except UnknownServiceError as ex:
            raise InvokeServerUnavailableError(str(ex))

        except Exception as ex:
            raise InvokeError(str(ex))
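
For context, not part of this diff: the two invocation styles return different body shapes, which is why separate handlers are used below. A rough sketch of reading each, assuming the standard bedrock-runtime response format:

import json

def read_non_streaming(response):
    # response["body"] is a botocore StreamingBody holding one JSON document
    return json.loads(response["body"].read())

def read_streaming(response):
    # response["body"] is an event stream; each "chunk" event carries a JSON
    # fragment in its "bytes" field
    for event in response["body"]:
        chunk = event.get("chunk")
        if chunk:
            yield json.loads(chunk["bytes"].decode("utf-8"))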

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: