| @@ -1,6 +1,7 @@ | |||
| import logging | |||
| from typing import List | |||
| import numpy as np | |||
| from langchain.embeddings.base import Embeddings | |||
| from sqlalchemy.exc import IntegrityError | |||
| @@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings): | |||
| embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| i = 0 | |||
| normalized_embedding_results = [] | |||
| for text in embedding_queue_texts: | |||
| hash = helper.generate_text_hash(text) | |||
| try: | |||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||
| embedding.set_embedding(embedding_results[i]) | |||
| vector = embedding_results[i] | |||
| normalized_embedding = (vector / np.linalg.norm(vector)).tolist() | |||
| normalized_embedding_results.append(normalized_embedding) | |||
| embedding.set_embedding(normalized_embedding) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| except IntegrityError: | |||
| @@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings): | |||
| finally: | |||
| i += 1 | |||
| text_embeddings.extend(embedding_results) | |||
| text_embeddings.extend(normalized_embedding_results) | |||
| return text_embeddings | |||
| def embed_query(self, text: str) -> List[float]: | |||
| @@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings): | |||
| try: | |||
| embedding_results = self._embeddings.client.embed_query(text) | |||
| embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist() | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| @@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings): | |||
| return embedding_results | |||