 from typing import List, Optional, Any
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
 from langchain.schema import LLMResult
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM
 class OpenLLMModel(BaseLLM):
         client = OpenLLM(
             server_url=self.credentials.get('server_url'),
             callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
         )
         return client
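The hunk above swaps langchain's bundled `OpenLLM` client for the in-repo HTTP wrapper and hands the generation parameters over as a single `llm_kwargs` dict instead of splatting them into the constructor. A minimal sketch of using the rewired client, assuming an OpenLLM server is listening on the default `openllm start` port:

```python
from core.third_party.langchain.llms.openllm import OpenLLM

# llm_kwargs travels as one dict; the wrapper forwards it to the server
# as `llm_config` on every request.
client = OpenLLM(
    server_url='http://localhost:3000',  # assumption: default `openllm start` port
    llm_kwargs={'max_new_tokens': 128, 'temperature': 0.7},
)
print(client('What is the difference between a duck and a goose?'))
```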
 import json
 from typing import Type
-from langchain.llms import OpenLLM
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )
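The new `alias='max_new_tokens'` lets the provider keep exposing `max_tokens` to callers while the OpenLLM server receives its native parameter name. Dify's actual alias handling lives in its provider base classes; the snippet below is only a hypothetical illustration of the rename the rule implies:

```python
# Hypothetical helper, not Dify code: rename UI-facing kwargs to the
# names the OpenLLM server expects, as KwargRule.alias suggests.
ALIASES = {'max_tokens': 'max_new_tokens'}

def to_server_kwargs(user_kwargs: dict) -> dict:
    # Keys without an alias pass through unchanged.
    return {ALIASES.get(k, k): v for k, v in user_kwargs.items()}

assert to_server_kwargs({'max_tokens': 128, 'top_p': 0.7}) == {
    'max_new_tokens': 128,
    'top_p': 0.7,
}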
     @classmethod
         }
         llm = OpenLLM(
-            max_tokens=10,
+            llm_kwargs={
+                'max_new_tokens': 10
+            },
             **credential_kwargs
         )
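Credential validation now requests a deliberately tiny generation (`max_new_tokens: 10`) through `llm_kwargs` rather than passing `max_tokens`, which the server would not recognize. The diff cuts off before the call itself; a plausible completion of the check, where the invocation and error translation are assumptions rather than shown code:

```python
# Sketch only: the generation call below is assumed, not shown in the
# diff. CredentialsValidateFailedError is imported at the top of the file.
try:
    llm('ping')  # one cheap request proves server_url points at a live server
except Exception as ex:
    raise CredentialsValidateFailedError(str(ex))
```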
+from __future__ import annotations
+
+import logging
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+import requests
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import Field
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLM(LLM):
+    """A lightweight OpenLLM wrapper that talks to a remote OpenLLM server
+    over HTTP; unlike langchain's built-in client, it does not run a model
+    in-process.
+
+    If you have an OpenLLM server running, point the wrapper at it:
+
+    .. code-block:: python
+
+        from core.third_party.langchain.llms.openllm import OpenLLM
+        llm = OpenLLM(server_url='http://localhost:3000')
+        llm("What is the difference between a duck and a goose?")
+    """
+
+    server_url: Optional[str] = None
+    """URL of a server started with 'openllm start'."""
+    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments sent to the server as ``llm_config`` with each request."""
+
+    @property
+    def _llm_type(self) -> str:
+        return "openllm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        params = {
+            "prompt": prompt,
+            "llm_config": self.llm_kwargs
+        }
+
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(
+            f'{self.server_url}/v1/generate',
+            headers=headers,
+            json=params
+        )
+
+        if not response.ok:
+            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
+
+        json_response = response.json()
+        completion = json_response["responses"][0]
+        if completion:
+            # The response is expected to begin with the echoed prompt;
+            # slice it off so only newly generated text is returned.
+            completion = completion[len(prompt):]
+
+        if stop is not None:
+            completion = enforce_stop_tokens(completion, stop)
+
+        return completion
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        raise NotImplementedError(
+            "Async call is not supported for OpenLLM at the moment."
+        )
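Everything the wrapper does on the wire can be reproduced with `requests` directly, which is handy when debugging a server. The endpoint and payload shape come from `_call` above; the URL and generation config here are placeholders:

```python
import requests

resp = requests.post(
    'http://localhost:3000/v1/generate',  # placeholder server URL
    headers={'Content-Type': 'application/json'},
    json={
        'prompt': 'What is OpenLLM?',
        'llm_config': {'max_new_tokens': 64},  # same dict _call sends as llm_config
    },
)
resp.raise_for_status()
# The first entry of `responses` holds the generation; `_call` slices the
# echoed prompt off the front before returning it.
print(resp.json()['responses'][0])
```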
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
 xinference==0.2.0
-openllm~=0.2.26
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                  return_value="abc")
     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
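Because the wrapper no longer drives an in-process `openllm.LLM` runner, the tests can drop the `_identifying_params` stub and only patch `_call` at its new dotted path. A minimal test in that shape; `MODEL_PROVIDER_CLASS`, the argument names, and the credential values below are stand-ins for the real fixtures, not confirmed by the diff:

```python
# Sketch under assumptions: the method signature and fixture names mirror
# the test file, but the model name and server URL are invented placeholders.
def test_call_is_stubbed(mocker):
    mocker.patch(
        'core.third_party.langchain.llms.openllm.OpenLLM._call',
        return_value='abc',
    )
    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
        model_name='facebook/opt-125m',
        model_type=ModelType.TEXT_GENERATION,
        credentials={'server_url': 'http://127.0.0.1:3000'},
    )
```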