| @@ -18,31 +18,30 @@ class CacheEmbedding(Embeddings): | |||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||
| """Embed search docs.""" | |||
| # use doc embedding cache or store if not exists | |||
| text_embeddings = [] | |||
| embedding_queue_texts = [] | |||
| for text in texts: | |||
| text_embeddings = [None for _ in range(len(texts))] | |||
| embedding_queue_indices = [] | |||
| for i, text in enumerate(texts): | |||
| hash = helper.generate_text_hash(text) | |||
| embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() | |||
| if embedding: | |||
| text_embeddings.append(embedding.get_embedding()) | |||
| text_embeddings[i] = embedding.get_embedding() | |||
| else: | |||
| embedding_queue_texts.append(text) | |||
| embedding_queue_indices.append(i) | |||
| if embedding_queue_texts: | |||
| if embedding_queue_indices: | |||
| try: | |||
| embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) | |||
| embedding_results = self._embeddings.client.embed_documents([texts[i] for i in embedding_queue_indices]) | |||
| except Exception as ex: | |||
| raise self._embeddings.handle_exceptions(ex) | |||
| i = 0 | |||
| normalized_embedding_results = [] | |||
| for text in embedding_queue_texts: | |||
| hash = helper.generate_text_hash(text) | |||
| for i, indice in enumerate(embedding_queue_indices): | |||
| hash = helper.generate_text_hash(texts[indice]) | |||
| try: | |||
| embedding = Embedding(model_name=self._embeddings.name, hash=hash) | |||
| vector = embedding_results[i] | |||
| normalized_embedding = (vector / np.linalg.norm(vector)).tolist() | |||
| normalized_embedding_results.append(normalized_embedding) | |||
| text_embeddings[indice] = normalized_embedding | |||
| embedding.set_embedding(normalized_embedding) | |||
| db.session.add(embedding) | |||
| db.session.commit() | |||
| @@ -52,10 +51,7 @@ class CacheEmbedding(Embeddings): | |||
| except: | |||
| logging.exception('Failed to add embedding to db') | |||
| continue | |||
| finally: | |||
| i += 1 | |||
| text_embeddings.extend(normalized_embedding_results) | |||
| return text_embeddings | |||
| def embed_query(self, text: str) -> List[float]: | |||