소스 검색

normalize embedding (#974)

Co-authored-by: jyong <jyong@dify.ai>
tags/0.3.16
Jyong 2 년 전
부모
커밋
1fc57d7358
No account linked to committer's email address
1개의 변경된 파일8개의 추가작업 그리고 4개의 파일을 삭제
  1. 8
    4
      api/core/embedding/cached_embedding.py

+ 8
- 4
api/core/embedding/cached_embedding.py 파일 보기

@@ -1,6 +1,7 @@
import logging
from typing import List

import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

@@ -32,14 +33,17 @@ class CacheEmbedding(Embeddings):
embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts)
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)

i = 0
normalized_embedding_results = []
for text in embedding_queue_texts:
hash = helper.generate_text_hash(text)

try:
embedding = Embedding(model_name=self._embeddings.name, hash=hash)
embedding.set_embedding(embedding_results[i])
vector = embedding_results[i]
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
normalized_embedding_results.append(normalized_embedding)
embedding.set_embedding(normalized_embedding)
db.session.add(embedding)
db.session.commit()
except IntegrityError:
@@ -51,7 +55,7 @@ class CacheEmbedding(Embeddings):
finally:
i += 1

text_embeddings.extend(embedding_results)
text_embeddings.extend(normalized_embedding_results)
return text_embeddings

def embed_query(self, text: str) -> List[float]:
@@ -64,6 +68,7 @@ class CacheEmbedding(Embeddings):

try:
embedding_results = self._embeddings.client.embed_query(text)
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
except Exception as ex:
raise self._embeddings.handle_exceptions(ex)

@@ -79,4 +84,3 @@ class CacheEmbedding(Embeddings):

return embedding_results



Loading…
취소
저장