| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 | 
							- from typing import Optional, Any, List
 - 
 - import openai
 - from llama_index.embeddings.base import BaseEmbedding
 - from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
 -     _TEXT_MODE_MODEL_DICT
 - from tenacity import wait_random_exponential, retry, stop_after_attempt
 - 
 - from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
 - 
 - 
 - @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 - def get_embedding(
 -         text: str,
 -         engine: Optional[str] = None,
 -         api_key: Optional[str] = None,
 -         **kwargs
 - ) -> List[float]:
 -     """Get embedding.
 - 
 -     NOTE: Copied from OpenAI's embedding utils:
 -     https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
 - 
 -     Copied here to avoid importing unnecessary dependencies
 -     like matplotlib, plotly, scipy, sklearn.
 - 
 -     """
 -     text = text.replace("\n", " ")
 -     return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]
 - 
 - 
 - @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 - async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs) -> List[
 -     float]:
 -     """Asynchronously get embedding.
 - 
 -     NOTE: Copied from OpenAI's embedding utils:
 -     https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
 - 
 -     Copied here to avoid importing unnecessary dependencies
 -     like matplotlib, plotly, scipy, sklearn.
 - 
 -     """
 -     # replace newlines, which can negatively affect performance.
 -     text = text.replace("\n", " ")
 - 
 -     return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
 -         "embedding"
 -     ]
 - 
 - 
 - @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 - def get_embeddings(
 -         list_of_text: List[str],
 -         engine: Optional[str] = None,
 -         api_key: Optional[str] = None,
 -         **kwargs
 - ) -> List[List[float]]:
 -     """Get embeddings.
 - 
 -     NOTE: Copied from OpenAI's embedding utils:
 -     https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
 - 
 -     Copied here to avoid importing unnecessary dependencies
 -     like matplotlib, plotly, scipy, sklearn.
 - 
 -     """
 -     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
 - 
 -     # replace newlines, which can negatively affect performance.
 -     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 - 
 -     data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
 -     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
 -     return [d["embedding"] for d in data]
 - 
 - 
 - @retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 - async def aget_embeddings(
 -         list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
 - ) -> List[List[float]]:
 -     """Asynchronously get embeddings.
 - 
 -     NOTE: Copied from OpenAI's embedding utils:
 -     https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
 - 
 -     Copied here to avoid importing unnecessary dependencies
 -     like matplotlib, plotly, scipy, sklearn.
 - 
 -     """
 -     assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
 - 
 -     # replace newlines, which can negatively affect performance.
 -     list_of_text = [text.replace("\n", " ") for text in list_of_text]
 - 
 -     data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
 -     data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
 -     return [d["embedding"] for d in data]
 - 
 - 
 - class OpenAIEmbedding(BaseEmbedding):
 - 
 -     def __init__(
 -             self,
 -             mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
 -             model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
 -             deployment_name: Optional[str] = None,
 -             openai_api_key: Optional[str] = None,
 -             **kwargs: Any,
 -     ) -> None:
 -         """Init params."""
 -         new_kwargs = {}
 - 
 -         if 'embed_batch_size' in kwargs:
 -             new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']
 - 
 -         if 'tokenizer' in kwargs:
 -             new_kwargs['tokenizer'] = kwargs['tokenizer']
 - 
 -         super().__init__(**new_kwargs)
 -         self.mode = OpenAIEmbeddingMode(mode)
 -         self.model = OpenAIEmbeddingModelType(model)
 -         self.deployment_name = deployment_name
 -         self.openai_api_key = openai_api_key
 -         self.openai_api_type = kwargs.get('openai_api_type')
 -         self.openai_api_version = kwargs.get('openai_api_version')
 -         self.openai_api_base = kwargs.get('openai_api_base')
 - 
 -     @handle_llm_exceptions
 -     def _get_query_embedding(self, query: str) -> List[float]:
 -         """Get query embedding."""
 -         if self.deployment_name is not None:
 -             engine = self.deployment_name
 -         else:
 -             key = (self.mode, self.model)
 -             if key not in _QUERY_MODE_MODEL_DICT:
 -                 raise ValueError(f"Invalid mode, model combination: {key}")
 -             engine = _QUERY_MODE_MODEL_DICT[key]
 -         return get_embedding(query, engine=engine, api_key=self.openai_api_key,
 -                              api_type=self.openai_api_type, api_version=self.openai_api_version,
 -                              api_base=self.openai_api_base)
 - 
 -     def _get_text_embedding(self, text: str) -> List[float]:
 -         """Get text embedding."""
 -         if self.deployment_name is not None:
 -             engine = self.deployment_name
 -         else:
 -             key = (self.mode, self.model)
 -             if key not in _TEXT_MODE_MODEL_DICT:
 -                 raise ValueError(f"Invalid mode, model combination: {key}")
 -             engine = _TEXT_MODE_MODEL_DICT[key]
 -         return get_embedding(text, engine=engine, api_key=self.openai_api_key,
 -                              api_type=self.openai_api_type, api_version=self.openai_api_version,
 -                              api_base=self.openai_api_base)
 - 
 -     async def _aget_text_embedding(self, text: str) -> List[float]:
 -         """Asynchronously get text embedding."""
 -         if self.deployment_name is not None:
 -             engine = self.deployment_name
 -         else:
 -             key = (self.mode, self.model)
 -             if key not in _TEXT_MODE_MODEL_DICT:
 -                 raise ValueError(f"Invalid mode, model combination: {key}")
 -             engine = _TEXT_MODE_MODEL_DICT[key]
 -         return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
 -                                     api_type=self.openai_api_type, api_version=self.openai_api_version,
 -                                     api_base=self.openai_api_base)
 - 
 -     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
 -         """Get text embeddings.
 - 
 -         By default, this is a wrapper around _get_text_embedding.
 -         Can be overriden for batch queries.
 - 
 -         """
 -         if self.deployment_name is not None:
 -             engine = self.deployment_name
 -         else:
 -             key = (self.mode, self.model)
 -             if key not in _TEXT_MODE_MODEL_DICT:
 -                 raise ValueError(f"Invalid mode, model combination: {key}")
 -             engine = _TEXT_MODE_MODEL_DICT[key]
 -         embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
 -                                     api_type=self.openai_api_type, api_version=self.openai_api_version,
 -                                     api_base=self.openai_api_base)
 -         return embeddings
 - 
 -     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
 -         """Asynchronously get text embeddings."""
 -         if self.deployment_name is not None:
 -             engine = self.deployment_name
 -         else:
 -             key = (self.mode, self.model)
 -             if key not in _TEXT_MODE_MODEL_DICT:
 -                 raise ValueError(f"Invalid mode, model combination: {key}")
 -             engine = _TEXT_MODE_MODEL_DICT[key]
 -         embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
 -                                            api_type=self.openai_api_type, api_version=self.openai_api_version,
 -                                            api_base=self.openai_api_base)
 -         return embeddings
 
 
  |