| @@ -1,3 +1,5 @@ | |||
| import base64 | |||
| import json | |||
| import logging | |||
| from typing import List, Optional | |||
| @@ -5,6 +7,8 @@ import numpy as np | |||
| from core.model_manager import ModelInstance | |||
| from extensions.ext_database import db | |||
| from langchain.embeddings.base import Embeddings | |||
| from extensions.ext_redis import redis_client | |||
| from libs import helper | |||
| from models.dataset import Embedding | |||
| from sqlalchemy.exc import IntegrityError | |||
| @@ -24,9 +28,12 @@ class CacheEmbedding(Embeddings): | |||
| embedding_queue_indices = [] | |||
| for i, text in enumerate(texts): | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() | |||
| embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' | |||
| embedding = redis_client.get(embedding_cache_key) | |||
| if embedding: | |||
| text_embeddings[i] = embedding.get_embedding() | |||
| redis_client.expire(embedding_cache_key, 3600) | |||
| text_embeddings[i] = list(np.frombuffer(base64.b64decode(embedding), dtype="float")) | |||
| else: | |||
| embedding_queue_indices.append(i) | |||
| @@ -46,18 +53,24 @@ class CacheEmbedding(Embeddings): | |||
| hash = helper.generate_text_hash(texts[indice]) | |||
| try: | |||
| embedding = Embedding(model_name=self._model_instance.model, hash=hash) | |||
| embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' | |||
| vector = embedding_results[i] | |||
| normalized_embedding = (vector / np.linalg.norm(vector)).tolist() | |||
| text_embeddings[indice] = normalized_embedding | |||
| embedding.set_embedding(normalized_embedding) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| # encode embedding to base64 | |||
| embedding_vector = np.array(normalized_embedding) | |||
| vector_bytes = embedding_vector.tobytes() | |||
| # Transform to Base64 | |||
| encoded_vector = base64.b64encode(vector_bytes) | |||
| # Transform to string | |||
| encoded_str = encoded_vector.decode("utf-8") | |||
| redis_client.setex(embedding_cache_key, 3600, encoded_str) | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| continue | |||
| except: | |||
| logging.exception('Failed to add embedding to db') | |||
| logging.exception('Failed to add embedding to redis') | |||
| continue | |||
| return text_embeddings | |||
| @@ -66,9 +79,12 @@ class CacheEmbedding(Embeddings): | |||
| """Embed query text.""" | |||
| # use doc embedding cache or store if not exists | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model, hash=hash).first() | |||
| embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}' | |||
| embedding = redis_client.get(embedding_cache_key) | |||
| if embedding: | |||
| return embedding.get_embedding() | |||
| redis_client.expire(embedding_cache_key, 3600) | |||
| return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) | |||
| try: | |||
| embedding_result = self._model_instance.invoke_text_embedding( | |||
| @@ -82,13 +98,18 @@ class CacheEmbedding(Embeddings): | |||
| raise ex | |||
| try: | |||
| embedding = Embedding(model_name=self._model_instance.model, hash=hash) | |||
| embedding.set_embedding(embedding_results) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| # encode embedding to base64 | |||
| embedding_vector = np.array(embedding_results) | |||
| vector_bytes = embedding_vector.tobytes() | |||
| # Transform to Base64 | |||
| encoded_vector = base64.b64encode(vector_bytes) | |||
| # Transform to string | |||
| encoded_str = encoded_vector.decode("utf-8") | |||
| redis_client.setex(embedding_cache_key, 3600, encoded_str) | |||
| except IntegrityError: | |||
| db.session.rollback() | |||
| except: | |||
| logging.exception('Failed to add embedding to db') | |||
| logging.exception('Failed to add embedding to redis') | |||
| return embedding_results | |||