Changes to the HuggingfaceHubModel wrapper (unified-diff markers; context lines prefixed with a space):

 from typing import List, Optional, Any
-from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
 class HuggingfaceHubModel(BaseLLM):
@@
                 streaming=streaming
             )
         else:
-            client = HuggingFaceHub(
+            client = HuggingFaceHubLLM(
                 repo_id=self.name,
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
@@
         if 'baichuan' in self.name.lower():
             return False
-        return True
+            return True
+        else:
+            return False
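This hunk is truncated: the outer condition that the re-indented `return True` and the new `else` branch belong to is not visible above. A plausible reading of the resulting method, assuming it is the streaming-support check and that the new branch keys off the provider's API type (both the property name and the condition are assumptions, not shown in the diff):

# Sketch only -- the property name and the outer condition are assumed.
@property
def support_streaming(self):
    if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':  # assumed condition
        # baichuan models are reported as non-streaming.
        if 'baichuan' in self.name.lower():
            return False
        return True
    else:
        return False

This reading is consistent with the first hunk, where `streaming=streaming` is passed only to the `HuggingFaceEndpointLLM` client and not to the hub client.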
@@
             raise CredentialsValidateFailedError('Task Type must be provided.')
         if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
-            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
+                                                 'text-generation, summarization.')
         try:
             llm = HuggingFaceEndpointLLM(
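The second part of the change adds the wrapper class itself as a new module (core/third_party/langchain/llms/huggingface_hub_llm.py, per the import in the first hunk), shown in full: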
from typing import Any, Dict, List, Optional

from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator
class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFace Hub models, with extra validation against the Hub.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only `text-generation`, `text2text-generation` and `summarization` are supported for now.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM

            hf = HuggingFaceHubLLM(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
    """
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is present and build the inference client."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        # Fail fast rather than blocking while a cold hosted model loads.
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Before calling the hosted Inference API, verify on the Hub that the
        # model exists, that inference has not been disabled in its model
        # card, and that its pipeline tag is a supported task.
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        # cardData may be absent entirely, so guard before indexing into it.
        if model_info.cardData and 'inference' in model_info.cardData \
                and not model_info.cardData['inference']:
            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(f"Model {self.repo_id} has task {model_info.pipeline_tag}, "
                             f"which is not one of the valid tasks {VALID_TASKS}.")

        return super()._call(prompt, stop, run_manager, **kwargs)
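A minimal usage sketch, assuming the module path implied by the import above, a valid Hugging Face token, and `gpt2` as an illustrative `repo_id` (none of which appear in the diff):

# Hypothetical usage -- repo_id, kwargs, and token value are illustrative only.
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM

llm = HuggingFaceHubLLM(
    repo_id="gpt2",
    task="text-generation",
    model_kwargs={"temperature": 0.7, "max_new_tokens": 64},
    huggingfacehub_api_token="hf_...",  # or set HUGGINGFACEHUB_API_TOKEN
)

# The _call override validates the model against the Hub first, so a missing
# model, a model card with `inference: false`, or an unsupported pipeline tag
# raises a descriptive ValueError instead of an opaque HTTP error.
print(llm("The quick brown fox"))

Running the checks inside _call keeps construction cheap, at the cost of one extra model_info request per generation.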