瀏覽代碼

fix: Fix the problem of system not working (#2884)

tags/0.5.10
Su Yang 1 年之前
父節點
當前提交
507aa6d949
沒有連結到貢獻者的電子郵件帳戶。
共有 1 個檔案被更改,包括 36 行新增11 行删除
  1. 36
    11
      api/core/model_runtime/model_providers/bedrock/llm/llm.py

+ 36
- 11
api/core/model_runtime/model_providers/bedrock/llm/llm.py 查看文件

@@ -74,12 +74,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

# invoke claude 3 models via anthropic official SDK
if "anthropic.claude-3" in model:
return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream)
return self._invoke_claude3(model, credentials, prompt_messages, model_parameters, stop, stream, user)
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

def _invoke_claude3(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
stop: Optional[list[str]] = None, stream: bool = True) -> Union[LLMResult, Generator]:
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke Claude3 large language model

@@ -100,22 +100,38 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
aws_region=credentials["aws_region"],
)

extra_model_kwargs = {}
if stop:
extra_model_kwargs['stop_sequences'] = stop

# Notice: If you request the current version of the SDK to the bedrock server,
# you will get the following error message and you need to wait for the service or SDK to be updated.
# Response: Error code: 400
# {'message': 'Malformed input request: #: subject must not be valid against schema
# {"required":["messages"]}#: extraneous key [metadata] is not permitted, please reformat your input and try again.'}
# TODO: Open in the future when the interface is properly supported
# if user:
# ref: https://github.com/anthropics/anthropic-sdk-python/blob/e84645b07ca5267066700a104b4d8d6a8da1383d/src/anthropic/resources/messages.py#L465
# extra_model_kwargs['metadata'] = message_create_params.Metadata(user_id=user)

system, prompt_message_dicts = self._convert_claude3_prompt_messages(prompt_messages)

if system:
extra_model_kwargs['system'] = system

response = client.messages.create(
model=model,
messages=prompt_message_dicts,
stop_sequences=stop if stop else [],
system=system,
stream=stream,
**model_parameters,
**extra_model_kwargs
)

if stream is False:
return self._handle_claude3_response(model, credentials, response, prompt_messages)
else:
if stream:
return self._handle_claude3_stream_response(model, credentials, response, prompt_messages)

return self._handle_claude3_response(model, credentials, response, prompt_messages)

def _handle_claude3_response(self, model: str, credentials: dict, response: Message,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
@@ -263,13 +279,22 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
"""
Convert prompt messages to dict list and system
"""
system = ""
prompt_message_dicts = []

system = ""
first_loop = True
for message in prompt_messages:
if isinstance(message, SystemPromptMessage):
system += message.content + ("\n" if not system else "")
else:
message.content=message.content.strip()
if first_loop:
system=message.content
first_loop=False
else:
system+="\n"
system+=message.content

prompt_message_dicts = []
for message in prompt_messages:
if not isinstance(message, SystemPromptMessage):
prompt_message_dicts.append(self._convert_claude3_prompt_message_to_dict(message))

return system, prompt_message_dicts

Loading…
取消
儲存