
streamable_open_ai.py 2.5KB

import os
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI
from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableOpenAI(OpenAI):
    """OpenAI completion LLM wrapper with project-specific error handling."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        # Streaming only supports a single candidate completion.
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Parameters passed to the OpenAI client for each call."""
        return {**super()._invocation_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM configuration."""
        return {**super()._identifying_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_llm_exceptions
    def generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        return super().generate(prompts, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        return await super().agenerate(prompts, stop, callbacks, **kwargs)
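
For reference, a minimal usage sketch. This is an assumption-laden example, not part of the file: it presumes a LangChain version compatible with this subclass, that the project's core.llm.error_handle_wraps module is on the import path, and that OPENAI_API_KEY is set in the environment; the model name and prompt are illustrative only.

    # Hypothetical usage sketch for StreamableOpenAI.
    llm = StreamableOpenAI(
        model_name="text-davinci-003",  # assumption: any completion-capable model
        streaming=False,                # streaming=True requires n == 1 and best_of == 1
        openai_api_key=os.environ["OPENAI_API_KEY"],
    )

    # generate() is the decorated synchronous entry point; agenerate() is the async variant.
    result = llm.generate(["Say hello in one short sentence."])
    print(result.generations[0][0].text)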