@@ -1,13 +1,13 @@
 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM


 class OpenLLMModel(BaseLLM):
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
         client = OpenLLM(
             server_url=self.credentials.get('server_url'),
             callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
         )

         return client
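Note: generation parameters are no longer splatted onto the client as top-level constructor arguments; they travel as a single `llm_kwargs` dict, which the wrapper below forwards verbatim in the request body. A minimal sketch of the resulting call, with illustrative values only:

    # assumed values, for illustration only
    client = OpenLLM(
        server_url='http://localhost:3000',
        llm_kwargs={'temperature': 0.7, 'max_new_tokens': 128},
    )
    client("Hello")  # POSTs {'prompt': 'Hello', 'llm_config': {...}} to /v1/generate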
@@ -1,14 +1,13 @@
 import json
 from typing import Type

-from langchain.llms import OpenLLM
-
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType
@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )

     @classmethod
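The new `alias` on `max_tokens` lets the UI keep its generic parameter name while the value reaches OpenLLM under the backend's `max_new_tokens`. A rough sketch of how such an alias could be resolved when building `provider_model_kwargs` (this helper is hypothetical; the actual mapping code is not part of this diff):

    def resolve_kwargs(rules, model_kwargs):
        # Prefer a rule's alias over its generic name when emitting
        # provider-facing kwargs, e.g. max_tokens -> max_new_tokens.
        resolved = {}
        for name, value in model_kwargs.items():
            rule = getattr(rules, name, None)
            alias = getattr(rule, 'alias', None)
            resolved[alias or name] = value
        return resolved  # e.g. {'max_new_tokens': 128, 'temperature': 0.7}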
@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
             }

             llm = OpenLLM(
-                max_tokens=10,
+                llm_kwargs={
+                    'max_new_tokens': 10
+                },
                 **credential_kwargs
             )
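With this change, credential validation asks the remote server for a tiny 10-token completion rather than configuring a local model. Roughly the HTTP exchange it triggers, assuming a server at the configured URL (the prompt text is illustrative):

    import requests

    resp = requests.post(
        'http://localhost:3000/v1/generate',  # credential_kwargs['server_url']
        headers={'Content-Type': 'application/json'},
        json={'prompt': 'ping', 'llm_config': {'max_new_tokens': 10}},
    )
    # A non-2xx status raises in the wrapper and is treated upstream
    # as invalid credentials.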
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+import requests
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import Field
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLM(LLM):
+    """OpenLLM wrapper for a remote OpenLLM server.
+
+    If you have an OpenLLM server running, you can use it remotely:
+
+        .. code-block:: python
+
+            from core.third_party.langchain.llms.openllm import OpenLLM
+            llm = OpenLLM(server_url='http://localhost:3000')
+            llm("What is the difference between a duck and a goose?")
+    """
+
+    server_url: Optional[str] = None
+    """Optional URL of an LLMServer started with 'openllm start'."""
+    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments sent to the server as the request's llm_config."""
+
+    @property
+    def _llm_type(self) -> str:
+        return "openllm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
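+        # Send the prompt plus any generation kwargs; the kwargs travel
+        # under the "llm_config" field of the JSON body.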
+        params = {
+            "prompt": prompt,
+            "llm_config": self.llm_kwargs
+        }
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(
+            f'{self.server_url}/v1/generate',
+            headers=headers,
+            json=params
+        )
+
+        if not response.ok:
+            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
+
+        json_response = response.json()
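+        # The returned text includes the original prompt, so strip that
+        # prefix before applying any stop sequences.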
+        completion = json_response["responses"][0]
+        if completion:
+            completion = completion[len(prompt):]
+
+        if stop is not None:
+            completion = enforce_stop_tokens(completion, stop)
+
+        return completion
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        raise NotImplementedError(
+            "Async call is not supported for OpenLLM at the moment."
+        )
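Taken together, the new file is a thin `requests` client around the server's `/v1/generate` endpoint. A quick usage sketch against a locally started server (URL and kwargs are assumptions):

    from core.third_party.langchain.llms.openllm import OpenLLM

    llm = OpenLLM(
        server_url='http://localhost:3000',
        llm_kwargs={'max_new_tokens': 64},
    )
    print(llm('What is a goose?', stop=['\n\n']))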
@@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.2.0
-openllm~=0.2.26
+xinference==0.2.0
@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):
 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(