@@ -1,16 +1,14 @@
 import decimal
 from functools import wraps
 from typing import List, Optional, Any

 from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import HuggingFaceEndpoint
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM


 class HuggingfaceHubModel(BaseLLM):
@@ -19,12 +17,12 @@ class HuggingfaceHubModel(BaseLLM):
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
         if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
-            client = HuggingFaceEndpoint(
+            client = HuggingFaceEndpointLLM(
                 endpoint_url=self.credentials['huggingfacehub_endpoint_url'],
-                task='text2text-generation',
+                task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
                 huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'],
-                callbacks=self.callbacks,
+                callbacks=self.callbacks
             )
         else:
             client = HuggingFaceHub(
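
Note on the hunk above: the dedicated-endpoint branch now builds the custom `HuggingFaceEndpointLLM` wrapper and reads the task from the stored credentials instead of hard-coding `text2text-generation`. A minimal sketch of the credentials mapping this branch now expects; the key names are taken from the diff, the values are placeholders:

```python
# Placeholder values; the key names are exactly the ones _init_client reads above.
credentials = {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_api_token': 'hf_...',  # your Hugging Face API token
    'huggingfacehub_endpoint_url': 'https://<endpoint-id>.endpoints.huggingface.cloud',
    'task_type': 'text-generation',  # newly required; was hard-coded before this change
}
```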
@@ -2,7 +2,6 @@ import json
 from typing import Type

 from huggingface_hub import HfApi
-from langchain.llms import HuggingFaceEndpoint

 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
@@ -10,6 +9,7 @@ from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM

 from models.provider import ProviderType
@@ -85,10 +85,16 @@ class HuggingfaceHubProvider(BaseModelProvider):
         if 'huggingfacehub_endpoint_url' not in credentials:
             raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.')

+        if 'task_type' not in credentials:
+            raise CredentialsValidateFailedError('Task Type must be provided.')
+
+        if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+
         try:
-            llm = HuggingFaceEndpoint(
+            llm = HuggingFaceEndpointLLM(
                 endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                task="text2text-generation",
+                task=credentials['task_type'],
                 model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
                 huggingfacehub_api_token=credentials['huggingfacehub_api_token']
             )
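
The two checks added above fail fast, before any request is made to the endpoint. A standalone sketch of the same logic; the helper name is hypothetical, not part of the diff:

```python
# Hypothetical helper mirroring the validation added above.
VALID_TASK_TYPES = ("text2text-generation", "text-generation", "summarization")

def check_task_type(credentials: dict) -> None:
    if 'task_type' not in credentials:
        raise ValueError('Task Type must be provided.')
    if credentials['task_type'] not in VALID_TASK_TYPES:
        raise ValueError('Task Type must be one of %s.' % ', '.join(VALID_TASK_TYPES))

check_task_type({'task_type': 'summarization'})  # passes silently
# check_task_type({'task_type': 'image-classification'})  # would raise ValueError
```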
@@ -160,6 +166,10 @@ class HuggingfaceHubProvider(BaseModelProvider):
             }

         credentials = json.loads(provider_model.encrypted_config)
+
+        if 'task_type' not in credentials:
+            credentials['task_type'] = 'text-generation'
+
         if credentials['huggingfacehub_api_token']:
             credentials['huggingfacehub_api_token'] = encrypter.decrypt_token(
                 self.provider.tenant_id,
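
The `task_type` backfill above is a backward-compatibility step: credential records encrypted and stored before this change have no `task_type` key, so they are defaulted to `text-generation` when loaded, rather than failing at client construction.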
@@ -0,0 +1,39 @@
+from typing import Dict
+
+from langchain.llms import HuggingFaceEndpoint
+from langchain.utils import get_from_dict_or_env
+from pydantic import root_validator
+
+
+class HuggingFaceEndpointLLM(HuggingFaceEndpoint):
+    """HuggingFace Endpoint models.
+
+    To use, you should have the ``huggingface_hub`` python package installed, and the
+    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
+    it as a named parameter to the constructor.
+
+    Only supports `text-generation`, `text2text-generation` and `summarization` for now.
+
+    Example:
+        .. code-block:: python
+
+            from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+
+            endpoint_url = (
+                "https://abcdefghijklmnop.us-east-1.aws.endpoints.huggingface.cloud"
+            )
+            hf = HuggingFaceEndpointLLM(
+                endpoint_url=endpoint_url,
+                task="text-generation",
+                huggingfacehub_api_token="my-api-key"
+            )
+    """
+
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Resolve the API token from the constructor arguments or the
+        ``HUGGINGFACEHUB_API_TOKEN`` environment variable, replacing the
+        parent class's stricter environment validation.
+        """
+        huggingfacehub_api_token = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+        values["huggingfacehub_api_token"] = huggingfacehub_api_token
+        return values
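
For completeness, a hedged usage sketch of the new wrapper; the endpoint URL and token are placeholders and this snippet is not part of the diff:

```python
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM

# Placeholder endpoint and token; task must be one of the three validated task types.
llm = HuggingFaceEndpointLLM(
    endpoint_url="https://<endpoint-id>.endpoints.huggingface.cloud",
    task="text-generation",
    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
    huggingfacehub_api_token="hf_...",
)
print(llm("Q: What does this endpoint serve?\nA:"))  # standard LangChain LLM call
```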
@@ -17,7 +17,8 @@ HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = {
 INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = {
     'huggingfacehub_api_type': 'inference_endpoints',
     'huggingfacehub_api_token': 'valid_key',
-    'huggingfacehub_endpoint_url': 'valid_url'
+    'huggingfacehub_endpoint_url': 'valid_url',
+    'task_type': 'text-generation'
 }

 def encrypt_side_effect(tenant_id, encrypt_key):
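
A possible follow-up test shape for the new required key, mirroring the provider check added earlier; everything except the fixture name and exception type is an assumption, not part of the diff:

```python
import pytest

def test_missing_task_type_is_rejected():
    # Hypothetical test: dropping the new key should fail credential validation.
    credentials = dict(INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL)
    del credentials['task_type']
    with pytest.raises(CredentialsValidateFailedError):
        validate_model_credentials(credentials)  # validator name assumed
```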