@@ -25,6 +25,7 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
+from cohere import ChatMessage
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
@@ -48,6 +49,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
 
 logger = logging.getLogger(__name__)
 
@@ -75,8 +77,86 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # invoke anthropic models via anthropic official SDK
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        # invoke Cohere models via boto3 client
+        if "cohere.command-r" in model:
+            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke other models via boto3 client
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+    def _generate_cohere_chat(
+            self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+            tools: Optional[list[PromptMessageTool]] = None) -> Union[LLMResult, Generator]:
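+        # reuse the Cohere provider implementation for converting Dify prompt
+        # messages into Cohere's chat format, rather than re-implementing it here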
+        cohere_llm = CohereLargeLanguageModel()
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
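+
+        # "bedrock-runtime" is the Bedrock data-plane service that exposes
+        # InvokeModel and InvokeModelWithResponseStream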
+        runtime_client = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
+        if tools:
+            tools = cohere_llm._convert_tools(tools)
+            model_parameters['tools'] = tools
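+
+        # split the prompt list into the latest user message, the prior turns
+        # (chat history), and any tool results to feed back to the model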
+        message, chat_histories, tool_results \
+            = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
+
+        payload = {
+            **model_parameters,
+            **extra_model_kwargs,  # carries stop_sequences when stop words are set
+            "message": message,
+            "chat_history": chat_histories,
+        }
+
+        # need a workaround for ai21 models, which don't support streaming
+        if stream:
+            invoke = runtime_client.invoke_model_with_response_stream
+        else:
+            invoke = runtime_client.invoke_model
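+
+        # ChatMessage instances from the cohere SDK are not JSON-serializable by
+        # default, so json.dumps below gets a fallback that flattens them to dicts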
+        def serialize(obj):
+            if isinstance(obj, ChatMessage):
+                return obj.__dict__
+            raise TypeError(f"Type {type(obj)} not serializable")
+
+        try:
+            body_jsonstr = json.dumps(payload, default=serialize)
+            response = invoke(
+                modelId=model,
+                contentType="application/json",
+                accept="*/*",
+                body=body_jsonstr
+            )
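+        # map botocore's structured error responses onto the model runtime's
+        # invoke error hierarchy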
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
 
     def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
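
For reference, a minimal standalone sketch of the request this new code path assembles, using the same bedrock-runtime calls. This is an illustration rather than part of the change: the region and model id (cohere.command-r-v1:0) are placeholder assumptions, credentials are resolved from the default boto3 chain instead of explicit keys, and the last line assumes the response JSON carries the generated text under "text", as in Cohere's chat API.

    import json

    import boto3
    from botocore.config import Config

    # data-plane client; credentials come from the default boto3 chain here
    client = boto3.client("bedrock-runtime", config=Config(region_name="us-east-1"))

    # mirrors the payload built in _generate_cohere_chat: the latest user turn in
    # "message", earlier turns in "chat_history" using Cohere's role names
    payload = {
        "message": "What models can I run on Bedrock?",
        "chat_history": [
            {"role": "USER", "message": "Hi"},
            {"role": "CHATBOT", "message": "Hello! How can I help?"},
        ],
    }

    response = client.invoke_model(
        modelId="cohere.command-r-v1:0",  # placeholder model id
        contentType="application/json",
        accept="*/*",
        body=json.dumps(payload),
    )
    print(json.loads(response["body"].read())["text"])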