@@ -1,7 +1,7 @@
 from collections.abc import Generator
 from typing import Optional, Union
 
-from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
@@ -26,7 +26,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         self._update_credential(model, credentials)
@@ -46,7 +46,48 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     ) -> Union[LLMResult, Generator]:
         self._update_credential(model, credentials)
 
-        return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+        block_as_stream = False
+        if model.startswith("openai/o1"):
+            block_as_stream = True
+            stop = None
+
+        # invoke block as stream
+        if stream and block_as_stream:
+            return self._generate_block_as_stream(
+                model, credentials, prompt_messages, model_parameters, tools, stop, user
+            )
+        else:
+            return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def _generate_block_as_stream(
+        self,
+        model: str,
+        credentials: dict,
+        prompt_messages: list[PromptMessage],
+        model_parameters: dict,
+        tools: Optional[list[PromptMessageTool]] = None,
+        stop: Optional[list[str]] = None,
+        user: Optional[str] = None,
+    ) -> Generator:
+        resp: LLMResult = super()._generate(
+            model, credentials, prompt_messages, model_parameters, tools, stop, False, user
+        )
+
+        yield LLMResultChunk(
+            model=model,
+            prompt_messages=prompt_messages,
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=resp.message,
+                usage=self._calc_response_usage(
+                    model=model,
+                    credentials=credentials,
+                    prompt_tokens=resp.usage.prompt_tokens,
+                    completion_tokens=resp.usage.completion_tokens,
+                ),
+                finish_reason="stop",
+            ),
+        )
 
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
         self._update_credential(model, credentials)
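
For context, here is a minimal, self-contained sketch of the block-as-stream pattern the new _generate_block_as_stream method applies: since openai/o1 models cannot be streamed (and the patch also drops stop sequences for them), the non-streaming call runs once and its complete result is re-emitted as a single stream chunk, so callers expecting a generator keep working. The names below (BlockingResult, blocking_generate) are illustrative stand-ins, not Dify APIs.

# Sketch of the block-as-stream pattern (illustrative names, not Dify APIs).
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class BlockingResult:
    text: str
    prompt_tokens: int
    completion_tokens: int

def blocking_generate(prompt: str) -> BlockingResult:
    # Stand-in for the non-streaming super()._generate(..., stream=False) call.
    return BlockingResult(text=f"echo: {prompt}", prompt_tokens=3, completion_tokens=2)

def generate_block_as_stream(prompt: str) -> Generator[str, None, None]:
    # Execute the blocking call once, then emit the full answer as one chunk,
    # mirroring the single LLMResultChunk with finish_reason="stop" above.
    resp = blocking_generate(prompt)
    yield resp.text

for chunk in generate_block_as_stream("hello"):
    print(chunk)  # prints "echo: hello" exactly once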