@@ -1,5 +1,6 @@
 import json
 import logging
+import re
 from collections.abc import Generator, Iterator
 from typing import Any, Optional, Union, cast
@@ -131,115 +132,58 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
| """ | |||
| handle stream chat generate response | |||
| """ | |||
-        class ChunkProcessor:
-            def __init__(self):
-                self.buffer = bytearray()
-
-            def try_decode_chunk(self, chunk: bytes) -> list[dict]:
-                """Try to decode complete JSON objects from the chunk"""
-                self.buffer.extend(chunk)
-                results = []
-
-                while True:
-                    try:
-                        start = self.buffer.find(b"{")
-                        if start == -1:
-                            self.buffer.clear()
-                            break
-
-                        bracket_count = 0
-                        end = start
-                        for i in range(start, len(self.buffer)):
-                            if self.buffer[i] == ord("{"):
-                                bracket_count += 1
-                            elif self.buffer[i] == ord("}"):
-                                bracket_count -= 1
-                                if bracket_count == 0:
-                                    end = i + 1
-                                    break
-
-                        if bracket_count != 0:
-                            # Incomplete JSON; wait for more data
-                            if start > 0:
-                                self.buffer = self.buffer[start:]
-                            break
-
-                        json_bytes = self.buffer[start:end]
-                        try:
-                            data = json.loads(json_bytes)
-                            results.append(data)
-                            self.buffer = self.buffer[end:]
-                        except json.JSONDecodeError:
-                            self.buffer = self.buffer[start + 1 :]
-
-                    except Exception as e:
-                        logger.debug(f"Warning: Error processing chunk ({str(e)})")
-                        if start > 0:
-                            self.buffer = self.buffer[start:]
-                        break
-
-                return results
         full_response = ""
-        processor = ChunkProcessor()
-
-        try:
-            for chunk in resp:
-                json_objects = processor.try_decode_chunk(chunk)
-
-                for data in json_objects:
-                    if data.get("choices"):
-                        choice = data["choices"][0]
-                        if "delta" in choice and "content" in choice["delta"]:
-                            chunk_content = choice["delta"]["content"]
-                            assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[])
-
-                            if choice.get("finish_reason") is not None:
-                                temp_assistant_prompt_message = AssistantPromptMessage(
-                                    content=full_response, tool_calls=[]
-                                )
-                                prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
-                                completion_tokens = self._num_tokens_from_messages(
-                                    messages=[temp_assistant_prompt_message], tools=[]
-                                )
-                                usage = self._calc_response_usage(
-                                    model=model,
-                                    credentials=credentials,
-                                    prompt_tokens=prompt_tokens,
-                                    completion_tokens=completion_tokens,
-                                )
-
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(
-                                        index=0,
-                                        message=assistant_prompt_message,
-                                        finish_reason=choice["finish_reason"],
-                                        usage=usage,
-                                    ),
-                                )
-                            else:
-                                yield LLMResultChunk(
-                                    model=model,
-                                    prompt_messages=prompt_messages,
-                                    system_fingerprint=None,
-                                    delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message),
-                                )
-
-                            full_response += chunk_content
-
-        except Exception as e:
-            raise
-
-        if not full_response:
-            logger.warning("No content received from stream response")
| buffer = "" | |||
| for chunk_bytes in resp: | |||
| buffer += chunk_bytes.decode("utf-8") | |||
| last_idx = 0 | |||
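+            # An SSE event is a "data: {...}" payload terminated by a blank line;
+            # only complete events currently in the buffer are parsed below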
+            for match in re.finditer(r"^data:\s*(.+?)(\n\n)", buffer):
+                try:
+                    data = json.loads(match.group(1).strip())
+                    last_idx = match.span()[1]
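+                    # Advance past each event that parsed as JSON so only the
+                    # (possibly incomplete) tail is kept for the next chunk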
| if "content" in data["choices"][0]["delta"]: | |||
| chunk_content = data["choices"][0]["delta"]["content"] | |||
| assistant_prompt_message = AssistantPromptMessage(content=chunk_content, tool_calls=[]) | |||
| if data["choices"][0]["finish_reason"] is not None: | |||
| temp_assistant_prompt_message = AssistantPromptMessage(content=full_response, tool_calls=[]) | |||
| prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) | |||
| completion_tokens = self._num_tokens_from_messages( | |||
| messages=[temp_assistant_prompt_message], tools=[] | |||
| ) | |||
| usage = self._calc_response_usage( | |||
| model=model, | |||
| credentials=credentials, | |||
| prompt_tokens=prompt_tokens, | |||
| completion_tokens=completion_tokens, | |||
| ) | |||
| yield LLMResultChunk( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| system_fingerprint=None, | |||
| delta=LLMResultChunkDelta( | |||
| index=0, | |||
| message=assistant_prompt_message, | |||
| finish_reason=data["choices"][0]["finish_reason"], | |||
| usage=usage, | |||
| ), | |||
| ) | |||
| else: | |||
| yield LLMResultChunk( | |||
| model=model, | |||
| prompt_messages=prompt_messages, | |||
| system_fingerprint=None, | |||
| delta=LLMResultChunkDelta(index=0, message=assistant_prompt_message), | |||
| ) | |||
| full_response += chunk_content | |||
+                except (json.JSONDecodeError, KeyError, IndexError):
+                    logger.info("json parse exception, content: {}".format(match.group(1).strip()))
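+            # Keep only the unconsumed tail of the buffer for the next network chunk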
+            buffer = buffer[last_idx:]
     def _invoke(
         self,