@@ -370,29 +370,14 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :return:
         """
         prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
         if isinstance(messages, str):
             prompt = messages
         else:
-            prompt = self._convert_messages_to_prompt(messages, prefix)
+            prompt = self._convert_messages_to_prompt(messages, prefix, model_name)
 
         return self._get_num_tokens_by_gpt2(prompt)
 
-    def _convert_messages_to_prompt(self, model_prefix: str, messages: list[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Google model
-
-        :param messages: List of PromptMessage to combine.
-        :return: Combined string with necessary human_prompt and ai_prompt tags.
-        """
-        messages = messages.copy()  # don't mutate the original list
-
-        text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
-            for message in messages
-        )
-
-        return text.rstrip()
-
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
|
|
|
@@ -432,7 +417,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str) -> str:
+    def _convert_one_message_to_text(self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Convert a single message to a string.
 
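Because `model_name` defaults to `None`, existing call sites that pass only a message and prefix keep working; only the meta branch below reads it. A sketch of both call styles (the message object is assumed to be one of Dify's PromptMessage subclasses):

    text = self._convert_one_message_to_text(message, "anthropic")                           # model_name stays None
    text = self._convert_one_message_to_text(message, "meta", "llama3-8b-instruct-v1:0")     # selects the llama3 template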
|
|
|
@@ -446,10 +431,17 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             ai_prompt = "\n\nAssistant:"
 
         elif model_prefix == "meta":
-            human_prompt_prefix = "\n[INST]"
-            human_prompt_postfix = "[\\INST]\n"
-            ai_prompt = ""
+            # LLAMA3
+            if model_name.startswith("llama3"):
+                human_prompt_prefix = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
+                human_prompt_postfix = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                ai_prompt = "\n\nAssistant:"
+            else:
+                # LLAMA2
+                human_prompt_prefix = "\n[INST]"
+                human_prompt_postfix = "[\\INST]\n"
+                ai_prompt = ""
 
         elif model_prefix == "mistral":
             human_prompt_prefix = "<s>[INST]"
             human_prompt_postfix = "[\\INST]\n"
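With this branch, a user turn renders into Llama 3's header-token chat template instead of Llama 2's `[INST]` wrapping. A worked sketch, assuming the surrounding code concatenates prefix, message content, and postfix as usual:

    # model_name = "llama3-8b-instruct-v1:0" (llama3 branch), content = "Hello":
    #   "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    # model_name = "llama2-70b-chat-v1" (llama2 branch):
    #   "\n[INST]Hello[\INST]\n"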
|
|
|
@@ -478,11 +470,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
 
         :param messages: List of PromptMessage to combine.
+        :param model_name: specific model name. Optional; used only to distinguish llama2 from llama3
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
         if not messages:
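A usage sketch of the updated signature (UserPromptMessage is Dify's message class; the model name is illustrative):

    prompt = self._convert_messages_to_prompt(
        [UserPromptMessage(content="Hi")],
        model_prefix="meta",
        model_name="llama3-8b-instruct-v1:0",
    )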
|
|
|
@@ -493,18 +486,20 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         messages.append(AssistantPromptMessage(content=""))
 
         text = "".join(
-            self._convert_one_message_to_text(message, model_prefix)
+            self._convert_one_message_to_text(message, model_prefix, model_name)
            for message in messages
         )
 
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
 
-    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
+    def _create_payload(self, model: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
         """
         Create payload for bedrock api call depending on model provider
         """
         payload = dict()
+        model_prefix = model.split('.')[0]
+        model_name = model.split('.')[1]
 
         if model_prefix == "amazon":
             payload["textGenerationConfig"] = { **model_parameters }
|
|
|
@@ -544,7 +539,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         elif model_prefix == "meta":
             payload = { **model_parameters }
-            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
+            payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix, model_name)
 
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix}")
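So for meta models the payload is the flattened parameters plus the rendered prompt. A minimal sketch of the result (parameter values illustrative):

    # _create_payload("meta.llama3-8b-instruct-v1:0", messages, {"temperature": 0.7})
    # -> {"temperature": 0.7,
    #     "prompt": "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n..."}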
|
|
|
@@ -579,7 +574,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
 
         model_prefix = model.split('.')[0]
-        payload = self._create_payload(model_prefix, prompt_messages, model_parameters, stop, stream)
+        payload = self._create_payload(model, prompt_messages, model_parameters, stop, stream)
 
         # need workaround for ai21 models which doesn't support streaming
         if stream and model_prefix != "ai21":
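Downstream, this flag presumably selects between the two bedrock-runtime entry points. A hedged sketch using the real boto3 calls (client construction and response handling omitted; `client` is a boto3 "bedrock-runtime" client):

    import json

    body = json.dumps(payload)
    if stream and model_prefix != "ai21":
        response = client.invoke_model_with_response_stream(modelId=model, body=body)
    else:
        response = client.invoke_model(modelId=model, body=body)   # ai21 gets the non-streaming call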