
fix: hf hosted inference check (#1128)

tags/0.3.20
takatost committed 2 years ago
commit c4d8bdc3db

api/core/model_providers/models/llm/huggingface_hub_model.py (+5, -3)

@@ ... @@
 from typing import List, Optional, Any

-from langchain import HuggingFaceHub
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult

 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
@@ class HuggingfaceHubModel(BaseLLM): @@
                 streaming=streaming
             )
         else:
-            client = HuggingFaceHub(
+            client = HuggingFaceHubLLM(
                 repo_id=self.name,
                 task=self.credentials['task_type'],
                 model_kwargs=provider_model_kwargs,
@@ class HuggingfaceHubModel(BaseLLM): @@
             if 'baichuan' in self.name.lower():
                 return False

-        return True
+            return True
+        else:
+            return False
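
The last hunk only makes sense inside a streaming-support check that branches on the `huggingfacehub_api_type` credential: previously every configuration fell through to `return True`, while after the fix only dedicated inference endpoints (minus baichuan models) report streaming support. A minimal sketch of how the surrounding property presumably reads after this commit; the property name and the enclosing `if` are assumptions reconstructed from the hunk context, not shown in the diff:

    # Reconstruction for illustration only: the property name and the
    # enclosing credential check are assumed; only the inner lines
    # appear in the hunk above.
    @property
    def support_streaming(self):
        if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints':
            # Known exception: baichuan endpoints cannot stream.
            if 'baichuan' in self.name.lower():
                return False

            return True
        else:
            # The fix: hosted Inference API models never support streaming,
            # instead of falling through to `return True` as before.
            return False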

api/core/model_providers/providers/huggingface_hub_provider.py (+2, -1)

@@ ... @@
             raise CredentialsValidateFailedError('Task Type must be provided.')

         if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
-            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
+            raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
+                                                 'text-generation, summarization.')

         try:
             llm = HuggingFaceEndpointLLM(
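
This hunk is purely cosmetic line wrapping: Python concatenates adjacent string literals at parse time, so the raised message is byte-for-byte identical to before. A quick check:

    # Adjacent string literals are merged by the parser into one string.
    msg = ('Task Type must be one of text2text-generation, '
           'text-generation, summarization.')
    assert msg == 'Task Type must be one of text2text-generation, text-generation, summarization.'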

api/core/third_party/langchain/llms/huggingface_hub_llm.py (+62, -0, new file)

from typing import Dict, Optional, List, Any

from huggingface_hub import HfApi, InferenceApi
from langchain import HuggingFaceHub
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_hub import VALID_TASKS
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator


class HuggingFaceHubLLM(HuggingFaceHub):
    """HuggingFaceHub models.

    To use, you should have the ``huggingface_hub`` python package installed, and the
    environment variable ``HUGGINGFACEHUB_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    Only supports `text-generation`, `text2text-generation` and `summarization` for now.

    Example:
        .. code-block:: python

            from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
            hf = HuggingFaceHubLLM(repo_id="gpt2", huggingfacehub_api_token="my-api-key")
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exist in environment."""
        huggingfacehub_api_token = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )
        client = InferenceApi(
            repo_id=values["repo_id"],
            token=huggingfacehub_api_token,
            task=values.get("task"),
        )
        # Fail fast on the hosted Inference API: don't block waiting for the
        # model to load and don't request GPU time.
        client.options = {"wait_for_model": False, "use_gpu": False}
        values["client"] = client
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Pre-flight check: verify the repo exists and hosted inference is
        # actually enabled before sending the prompt.
        hfapi = HfApi(token=self.huggingfacehub_api_token)
        model_info = hfapi.model_info(repo_id=self.repo_id)
        if not model_info:
            raise ValueError(f"Model {self.repo_id} not found.")

        # `inference: false` in the model card means the hosted Inference API
        # is turned off for this model; `cardData` may be missing entirely.
        if model_info.cardData and 'inference' in model_info.cardData and not model_info.cardData['inference']:
            raise ValueError(f"Inference API has been turned off for this model {self.repo_id}.")

        if model_info.pipeline_tag not in VALID_TASKS:
            raise ValueError(f"Model {self.repo_id} does not have a valid task, "
                             f"must be one of {VALID_TASKS}.")

        return super()._call(prompt, stop, run_manager, **kwargs)
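
A minimal usage sketch of the new wrapper; the repo id, task, and token below are placeholders, and the token may equally come from the `HUGGINGFACEHUB_API_TOKEN` environment variable:

    from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM

    llm = HuggingFaceHubLLM(
        repo_id="gpt2",                     # placeholder repo
        task="text-generation",
        huggingfacehub_api_token="hf_...",  # placeholder token
    )

    try:
        print(llm("The quick brown fox"))
    except ValueError as err:
        # Raised by the pre-flight check in `_call` when the repo is missing,
        # the model card sets `inference: false`, or the pipeline tag is not
        # one of VALID_TASKS.
        print(err)

The pre-flight check costs one extra `model_info` request per generation, in exchange for a clear error message instead of an opaque failure from the hosted Inference API when inference is disabled for the model.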
