|
|
|
@@ -12,6 +12,7 @@ from core.rag.datasource.entity.embedding import Embeddings |
|
|
|
from extensions.ext_database import db |
|
|
|
from extensions.ext_redis import redis_client |
|
|
|
from libs import helper |
|
|
|
from models.dataset import Embedding |
|
|
|
|
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
@@ -23,32 +24,55 @@ class CacheEmbedding(Embeddings): |
|
|
|
|
|
|
|
def embed_documents(self, texts: list[str]) -> list[list[float]]: |
|
|
|
"""Embed search docs in batches of 10.""" |
|
|
|
text_embeddings = [] |
|
|
|
try: |
|
|
|
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) |
|
|
|
model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials) |
|
|
|
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ |
|
|
|
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 |
|
|
|
for i in range(0, len(texts), max_chunks): |
|
|
|
batch_texts = texts[i:i + max_chunks] |
|
|
|
|
|
|
|
embedding_result = self._model_instance.invoke_text_embedding( |
|
|
|
texts=batch_texts, |
|
|
|
user=self._user |
|
|
|
) |
|
|
|
|
|
|
|
for vector in embedding_result.embeddings: |
|
|
|
try: |
|
|
|
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() |
|
|
|
text_embeddings.append(normalized_embedding) |
|
|
|
except IntegrityError: |
|
|
|
db.session.rollback() |
|
|
|
except Exception as e: |
|
|
|
logging.exception('Failed to add embedding to redis') |
|
|
|
|
|
|
|
except Exception as ex: |
|
|
|
logger.error('Failed to embed documents: ', ex) |
|
|
|
raise ex |
|
|
|
# use doc embedding cache or store if not exists |
|
|
|
text_embeddings = [None for _ in range(len(texts))] |
|
|
|
embedding_queue_indices = [] |
|
|
|
for i, text in enumerate(texts): |
|
|
|
hash = helper.generate_text_hash(text) |
|
|
|
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, |
|
|
|
hash=hash, |
|
|
|
provider_name=self._model_instance.provider).first() |
|
|
|
if embedding: |
|
|
|
text_embeddings[i] = embedding.get_embedding() |
|
|
|
else: |
|
|
|
embedding_queue_indices.append(i) |
|
|
|
if embedding_queue_indices: |
|
|
|
embedding_queue_texts = [texts[i] for i in embedding_queue_indices] |
|
|
|
embedding_queue_embeddings = [] |
|
|
|
try: |
|
|
|
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) |
|
|
|
model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials) |
|
|
|
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \ |
|
|
|
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1 |
|
|
|
for i in range(0, len(embedding_queue_texts), max_chunks): |
|
|
|
batch_texts = embedding_queue_texts[i:i + max_chunks] |
|
|
|
|
|
|
|
embedding_result = self._model_instance.invoke_text_embedding( |
|
|
|
texts=batch_texts, |
|
|
|
user=self._user |
|
|
|
) |
|
|
|
|
|
|
|
for vector in embedding_result.embeddings: |
|
|
|
try: |
|
|
|
normalized_embedding = (vector / np.linalg.norm(vector)).tolist() |
|
|
|
embedding_queue_embeddings.append(normalized_embedding) |
|
|
|
except IntegrityError: |
|
|
|
db.session.rollback() |
|
|
|
except Exception as e: |
|
|
|
logging.exception('Failed transform embedding: ', e) |
|
|
|
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings): |
|
|
|
text_embeddings[i] = embedding |
|
|
|
hash = helper.generate_text_hash(texts[i]) |
|
|
|
embedding_cache = Embedding(model_name=self._model_instance.model, |
|
|
|
hash=hash, |
|
|
|
provider_name=self._model_instance.provider) |
|
|
|
embedding_cache.set_embedding(embedding) |
|
|
|
db.session.add(embedding_cache) |
|
|
|
db.session.commit() |
|
|
|
except Exception as ex: |
|
|
|
db.session.rollback() |
|
|
|
logger.error('Failed to embed documents: ', ex) |
|
|
|
raise ex |
|
|
|
|
|
|
|
return text_embeddings |
|
|
|
|
|
|
|
@@ -61,8 +85,6 @@ class CacheEmbedding(Embeddings): |
|
|
|
if embedding: |
|
|
|
redis_client.expire(embedding_cache_key, 600) |
|
|
|
return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
embedding_result = self._model_instance.invoke_text_embedding( |
|
|
|
texts=[text], |