|
|
|
@@ -17,7 +17,6 @@ from botocore.exceptions import ( |
|
|
|
ServiceNotInRegionError, |
|
|
|
UnknownServiceError, |
|
|
|
) |
|
|
|
from cohere import ChatMessage |
|
|
|
|
|
|
|
# local import |
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta |
|
|
|
@@ -42,7 +41,6 @@ from core.model_runtime.errors.invoke import ( |
|
|
|
) |
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError |
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel |
|
|
|
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
@@ -59,6 +57,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
{'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False}, |
|
|
|
{'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True}, |
|
|
|
{'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True}, |
|
|
|
{'prefix': 'cohere.command-r', 'support_system_prompts': True, 'support_tool_use': True}, |
|
|
|
{'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False} |
|
|
|
] |
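# For illustration, a prefix table like the one above can be matched against a
# model ID with a simple first-match scan; `find_model_support_sketch` is a
# hypothetical helper, not part of this class.
def find_model_support_sketch(model_id: str, table: list[dict]) -> dict:
    for entry in table:
        if model_id.startswith(entry['prefix']):
            return entry
    # conservative default when no prefix matches
    return {'prefix': '', 'support_system_prompts': False, 'support_tool_use': False}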
|
|
|
|
|
|
|
@@ -94,86 +93,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
model_info['model'] = model |
|
|
|
# invoke models via boto3 converse API |
|
|
|
return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools) |
|
|
|
# invoke Cohere models via boto3 client |
|
|
|
if "cohere.command-r" in model: |
|
|
|
return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools) |
|
|
|
# invoke other models via boto3 client |
|
|
|
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) |
|
|
|
|
|
|
|
def _generate_cohere_chat( |
|
|
|
self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, |
|
|
|
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, |
|
|
|
tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: |
|
|
|
cohere_llm = CohereLargeLanguageModel() |
|
|
|
client_config = Config( |
|
|
|
region_name=credentials["aws_region"] |
|
|
|
) |
|
|
|
|
|
|
|
runtime_client = boto3.client( |
|
|
|
service_name='bedrock-runtime', |
|
|
|
config=client_config, |
|
|
|
aws_access_key_id=credentials["aws_access_key_id"], |
|
|
|
aws_secret_access_key=credentials["aws_secret_access_key"] |
|
|
|
) |
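# Hypothetical helper, not part of this module: build the same bedrock-runtime
# client, but fall back to boto3's default credential chain when explicit keys
# are not supplied in the credentials dict.
def make_bedrock_runtime_client_sketch(credentials: dict):
    import boto3
    from botocore.config import Config
    kwargs = {
        'service_name': 'bedrock-runtime',
        'config': Config(region_name=credentials['aws_region']),
    }
    if credentials.get('aws_access_key_id') and credentials.get('aws_secret_access_key'):
        kwargs['aws_access_key_id'] = credentials['aws_access_key_id']
        kwargs['aws_secret_access_key'] = credentials['aws_secret_access_key']
    return boto3.client(**kwargs)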
|
|
|
|
|
|
|
extra_model_kwargs = {} |
|
|
|
if stop: |
|
|
|
extra_model_kwargs['stop_sequences'] = stop |
|
|
|
|
|
|
|
if tools: |
|
|
|
tools = cohere_llm._convert_tools(tools) |
|
|
|
model_parameters['tools'] = tools |
|
|
|
|
|
|
|
message, chat_histories, tool_results \ |
|
|
|
= cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages) |
|
|
|
|
|
|
|
if tool_results: |
|
|
|
model_parameters['tool_results'] = tool_results |
|
|
|
|
|
|
|
payload = { |
|
|
|
**model_parameters, |
|
|
|
"message": message, |
|
|
|
"chat_history": chat_histories, |
|
|
|
} |
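# Illustrative shape of the request body assembled above (example values only;
# chat_history items are cohere ChatMessage objects until serialized below):
example_cohere_payload = {
    "temperature": 0.3,
    "max_tokens": 1024,
    "message": "What's the weather in Tokyo?",
    "chat_history": [
        {"role": "USER", "message": "Hi"},
        {"role": "CHATBOT", "message": "Hello! How can I help?"},
    ],
}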
|
|
|
|
|
|
|
# select the streaming or non-streaming invoke API based on the stream flag
|
|
|
if stream: |
|
|
|
invoke = runtime_client.invoke_model_with_response_stream |
|
|
|
else: |
|
|
|
invoke = runtime_client.invoke_model |
|
|
|
|
|
|
|
def serialize(obj): |
|
|
|
if isinstance(obj, ChatMessage): |
|
|
|
return obj.__dict__ |
|
|
|
raise TypeError(f"Type {type(obj)} not serializable") |
|
|
|
|
|
|
|
try: |
|
|
|
body_jsonstr = json.dumps(payload, default=serialize)
|
|
|
response = invoke( |
|
|
|
modelId=model, |
|
|
|
contentType="application/json", |
|
|
|
accept="*/*", |
|
|
|
body=body_jsonstr |
|
|
|
) |
|
|
|
except ClientError as ex: |
|
|
|
error_code = ex.response['Error']['Code'] |
|
|
|
full_error_msg = f"{error_code}: {ex.response['Error']['Message']}" |
|
|
|
raise self._map_client_to_invoke_error(error_code, full_error_msg) |
|
|
|
|
|
|
|
except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex: |
|
|
|
raise InvokeConnectionError(str(ex)) |
|
|
|
|
|
|
|
except UnknownServiceError as ex: |
|
|
|
raise InvokeServerUnavailableError(str(ex)) |
|
|
|
|
|
|
|
except Exception as ex: |
|
|
|
raise InvokeError(str(ex)) |
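# Hypothetical shape of the error mapping used in the except blocks above
# (illustrative only; the real _map_client_to_invoke_error is defined elsewhere
# in this class, and the Invoke* classes come from core.model_runtime.errors.invoke):
def map_client_error_sketch(error_code: str, message: str) -> InvokeError:
    if error_code in ('AccessDeniedException', 'UnrecognizedClientException'):
        return InvokeAuthorizationError(message)
    if error_code == 'ThrottlingException':
        return InvokeRateLimitError(message)
    return InvokeBadRequestError(message)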
|
|
|
|
|
|
|
if stream: |
|
|
|
return self._handle_generate_stream_response(model, credentials, response, prompt_messages) |
|
|
|
|
|
|
|
return self._handle_generate_response(model, credentials, response, prompt_messages) |
|
|
|
|
|
|
|
|
|
|
|
def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, |
|
|
|
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]: |
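# A minimal sketch of the Bedrock Converse API call this method wraps
# (illustrative request only; `runtime_client` is a bedrock-runtime boto3 client
# and the inferenceConfig values are placeholders):
def converse_sketch(runtime_client, model_id: str, user_text: str, stream: bool = False):
    request = {
        'modelId': model_id,
        'messages': [{'role': 'user', 'content': [{'text': user_text}]}],
        'inferenceConfig': {'maxTokens': 512, 'temperature': 0.5, 'stopSequences': []},
    }
    # converse_stream returns an event stream; converse returns the full result
    return runtime_client.converse_stream(**request) if stream else runtime_client.converse(**request)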
|
|
|
@@ -581,38 +502,9 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
:param message: PromptMessage to convert. |
|
|
|
:return: String representation of the message. |
|
|
|
""" |
|
|
|
|
|
|
|
if model_prefix == "anthropic": |
|
|
|
human_prompt_prefix = "\n\nHuman:" |
|
|
|
human_prompt_postfix = "" |
|
|
|
ai_prompt = "\n\nAssistant:" |
|
|
|
|
|
|
|
elif model_prefix == "meta": |
|
|
|
# LLAMA3 |
|
|
|
if model_name.startswith("llama3"): |
|
|
|
human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n" |
|
|
|
human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" |
|
|
|
ai_prompt = "\n\nAssistant:" |
|
|
|
else: |
|
|
|
# LLAMA2 |
|
|
|
human_prompt_prefix = "\n[INST]" |
|
|
|
human_prompt_postfix = "[\\INST]\n" |
|
|
|
ai_prompt = "" |
|
|
|
|
|
|
|
elif model_prefix == "mistral": |
|
|
|
human_prompt_prefix = "<s>[INST]" |
|
|
|
human_prompt_postfix = "[\\INST]\n" |
|
|
|
ai_prompt = "\n\nAssistant:" |
|
|
|
|
|
|
|
elif model_prefix == "amazon": |
|
|
|
human_prompt_prefix = "\n\nUser:" |
|
|
|
human_prompt_postfix = "" |
|
|
|
ai_prompt = "\n\nBot:" |
|
|
|
|
|
|
|
else:
human_prompt_prefix = ""
human_prompt_postfix = ""
ai_prompt = ""
|
|
|
|
|
|
|
content = message.content |
|
|
|
|
|
|
|
@@ -663,13 +555,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
model_prefix = model.split('.')[0] |
|
|
|
model_name = model.split('.')[1] |
|
|
|
|
|
|
|
if model_prefix == "amazon": |
|
|
|
payload["textGenerationConfig"] = { **model_parameters } |
|
|
|
payload["textGenerationConfig"]["stopSequences"] = ["User:"] |
|
|
|
|
|
|
|
payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) |
|
|
|
|
|
|
|
elif model_prefix == "ai21": |
|
|
|
if model_prefix == "ai21": |
|
|
|
payload["temperature"] = model_parameters.get("temperature") |
|
|
|
payload["topP"] = model_parameters.get("topP") |
|
|
|
payload["maxTokens"] = model_parameters.get("maxTokens") |
|
|
|
@@ -681,28 +567,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")} |
|
|
|
if model_parameters.get("countPenalty"): |
|
|
|
payload["countPenalty"] = {model_parameters.get("countPenalty")} |
|
|
|
|
|
|
|
elif model_prefix == "mistral": |
|
|
|
payload["temperature"] = model_parameters.get("temperature") |
|
|
|
payload["top_p"] = model_parameters.get("top_p") |
|
|
|
payload["max_tokens"] = model_parameters.get("max_tokens") |
|
|
|
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) |
|
|
|
payload["stop"] = stop[:10] if stop else [] |
|
|
|
|
|
|
|
elif model_prefix == "anthropic": |
|
|
|
payload = { **model_parameters } |
|
|
|
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix) |
|
|
|
payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else []) |
|
|
|
|
|
|
|
|
|
|
|
elif model_prefix == "cohere": |
|
|
|
payload = { **model_parameters } |
|
|
|
payload["prompt"] = prompt_messages[0].content |
|
|
|
payload["stream"] = stream |
|
|
|
|
|
|
|
elif model_prefix == "meta": |
|
|
|
payload = { **model_parameters } |
|
|
|
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name) |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError(f"Got unknown model prefix {model_prefix}") |
|
|
|
|
|
|
|
@@ -793,36 +663,16 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
# get output text and calculate num tokens based on model / provider |
|
|
|
model_prefix = model.split('.')[0] |
|
|
|
|
|
|
|
if model_prefix == "amazon": |
|
|
|
output = response_body.get("results")[0].get("outputText").strip('\n') |
|
|
|
prompt_tokens = response_body.get("inputTextTokenCount") |
|
|
|
completion_tokens = response_body.get("results")[0].get("tokenCount") |
|
|
|
|
|
|
|
elif model_prefix == "ai21": |
|
|
|
if model_prefix == "ai21": |
|
|
|
output = response_body.get('completions')[0].get('data').get('text') |
|
|
|
prompt_tokens = len(response_body.get("prompt").get("tokens")) |
|
|
|
completion_tokens = len(response_body.get('completions')[0].get('data').get('tokens')) |
|
|
|
|
|
|
|
elif model_prefix == "anthropic": |
|
|
|
output = response_body.get("completion") |
|
|
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) |
|
|
|
completion_tokens = self.get_num_tokens(model, credentials, output if output else '') |
|
|
|
|
|
|
|
elif model_prefix == "cohere": |
|
|
|
output = response_body.get("generations")[0].get("text") |
|
|
|
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) |
|
|
|
completion_tokens = self.get_num_tokens(model, credentials, output if output else '') |
|
|
|
|
|
|
|
elif model_prefix == "meta": |
|
|
|
output = response_body.get("generation").strip('\n') |
|
|
|
prompt_tokens = response_body.get("prompt_token_count") |
|
|
|
completion_tokens = response_body.get("generation_token_count") |
|
|
|
|
|
|
|
elif model_prefix == "mistral": |
|
|
|
output = response_body.get("outputs")[0].get("text") |
|
|
|
# the Bedrock token-count headers arrive as strings, so cast them to int
prompt_tokens = int(response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count'))
completion_tokens = int(response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count'))
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") |
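# Illustrative Titan-style response body for the "amazon" branch above
# (field names follow Bedrock's InvokeModel response for Titan text models;
# values are examples only):
example_titan_body = {
    "inputTextTokenCount": 12,
    "results": [{"outputText": "Hello!", "tokenCount": 3, "completionReason": "FINISH"}],
}
example_output = example_titan_body.get("results")[0].get("outputText").strip('\n')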
|
|
|
|
|
|
|
@@ -893,26 +743,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel): |
|
|
|
payload = json.loads(chunk.get('bytes').decode()) |
|
|
|
|
|
|
|
model_prefix = model.split('.')[0] |
|
|
|
if model_prefix == "amazon": |
|
|
|
content_delta = payload.get("outputText").strip('\n') |
|
|
|
finish_reason = payload.get("completion_reason") |
|
|
|
|
|
|
|
elif model_prefix == "anthropic": |
|
|
|
content_delta = payload.get("completion") |
|
|
|
finish_reason = payload.get("stop_reason") |
|
|
|
|
|
|
|
elif model_prefix == "cohere": |
|
|
|
if model_prefix == "cohere": |
|
|
|
content_delta = payload.get("text") |
|
|
|
finish_reason = payload.get("finish_reason") |
|
|
|
|
|
|
|
elif model_prefix == "mistral": |
|
|
|
content_delta = payload.get('outputs')[0].get("text") |
|
|
|
finish_reason = payload.get('outputs')[0].get("stop_reason") |
|
|
|
|
|
|
|
elif model_prefix == "meta": |
|
|
|
content_delta = payload.get("generation").strip('\n') |
|
|
|
finish_reason = payload.get("stop_reason") |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError(f"Got unknown model prefix {model_prefix} when handling stream response") |
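# A minimal sketch of how stream chunks like `payload` above are read from a
# Bedrock streaming response (the enclosing loop sits outside this hunk):
def iter_stream_payloads_sketch(response):
    import json
    for event in response.get('body'):  # EventStream from invoke_model_with_response_stream
        chunk = event.get('chunk')
        if chunk:
            yield json.loads(chunk.get('bytes').decode())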
|
|
|
|