|
|
|
|
|
|
from typing import Dict, Optional, List, Any |
|
|
|
|
|
|
|
from huggingface_hub import HfApi, InferenceApi |
|
|
|
from langchain import HuggingFaceHub |
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun |
|
|
|
from langchain.llms.huggingface_hub import VALID_TASKS |
|
|
|
from pydantic import root_validator |
|
|
|
|
|
|
|
from langchain.utils import get_from_dict_or_env |
|
|
|
|
|
|
|
|
|
|
|
class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFaceHub models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

    Example:
        .. code-block:: python

            from langchain.llms import HuggingFaceHub
            hf = HuggingFaceHub(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API token from ``values`` or the
        ``HUGGINGFACEHUB_API_TOKEN`` environment variable, builds the
        ``InferenceApi`` client, and stores it under ``values["client"]``.
        """
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        # Fail fast when the model is still loading, and never request GPU time.
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Validate the repo on the Hub, then delegate generation to the parent.

        Args:
            prompt: The prompt to pass to the model.
            stop: Optional stop sequences, forwarded to the parent ``_call``.
            run_manager: Optional callback manager, forwarded to the parent.
            **kwargs: Extra keyword arguments, forwarded to the parent.

        Returns:
            The generated text.

        Raises:
            ValueError: If the model does not exist, has its Inference API
                turned off, or does not run one of the supported tasks.
        """
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        # cardData is None for repos without model-card metadata; guard before
        # the membership test to avoid `TypeError: argument of type 'NoneType'
        # is not iterable`.
        card_data = model_info.cardData or {}
        if 'inference' in card_data and not card_data['inference']:
            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(f"Model {self.repo_id} is not a valid task, "
                             f"must be one of {VALID_TASKS}.")

        return super()._call(prompt, stop, run_manager, **kwargs)