Co-authored-by: jyong <jyong@dify.ai>
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
+from models.dataset import Embedding

logger = logging.getLogger(__name__)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs in batches of 10."""
-        text_embeddings = []
-        try:
-            model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-            model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
-            max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
-                if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
-            for i in range(0, len(texts), max_chunks):
-                batch_texts = texts[i:i + max_chunks]
-                embedding_result = self._model_instance.invoke_text_embedding(
-                    texts=batch_texts,
-                    user=self._user
-                )
-                for vector in embedding_result.embeddings:
-                    try:
-                        normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
-                        text_embeddings.append(normalized_embedding)
-                    except IntegrityError:
-                        db.session.rollback()
-                    except Exception as e:
-                        logging.exception('Failed to add embedding to redis')
-        except Exception as ex:
-            logger.error('Failed to embed documents: ', ex)
-            raise ex
+        # use doc embedding cache or store if not exists
+        text_embeddings = [None for _ in range(len(texts))]
+        embedding_queue_indices = []
+        for i, text in enumerate(texts):
+            hash = helper.generate_text_hash(text)
+            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
+                                                              hash=hash,
+                                                              provider_name=self._model_instance.provider).first()
+            if embedding:
+                text_embeddings[i] = embedding.get_embedding()
+            else:
+                embedding_queue_indices.append(i)
+        if embedding_queue_indices:
+            embedding_queue_texts = [texts[i] for i in embedding_queue_indices]
+            embedding_queue_embeddings = []
+            try:
+                model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
+                model_schema = model_type_instance.get_model_schema(self._model_instance.model, self._model_instance.credentials)
+                max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
+                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
+                for i in range(0, len(embedding_queue_texts), max_chunks):
+                    batch_texts = embedding_queue_texts[i:i + max_chunks]
+                    embedding_result = self._model_instance.invoke_text_embedding(
+                        texts=batch_texts,
+                        user=self._user
+                    )
+                    for vector in embedding_result.embeddings:
+                        try:
+                            normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                            embedding_queue_embeddings.append(normalized_embedding)
+                        except IntegrityError:
+                            db.session.rollback()
+                        except Exception as e:
+                            logging.exception('Failed to transform embedding: %s', e)
+                for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+                    text_embeddings[i] = embedding
+                    hash = helper.generate_text_hash(texts[i])
+                    embedding_cache = Embedding(model_name=self._model_instance.model,
+                                                hash=hash,
+                                                provider_name=self._model_instance.provider)
+                    embedding_cache.set_embedding(embedding)
+                    db.session.add(embedding_cache)
+                    db.session.commit()
+            except Exception as ex:
+                db.session.rollback()
+                logger.error('Failed to embed documents: %s', ex)
+                raise ex
        return text_embeddings
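The (vector / np.linalg.norm(vector)).tolist() step above L2-normalizes every returned vector before it is cached or handed back. A small standalone illustration of why that matters (on unit-length vectors, cosine similarity reduces to a plain dot product); this is an explanatory sketch, not part of the diff:

import numpy as np

# Normalize two raw embedding vectors the same way embed_documents does.
a = np.array([3.0, 4.0])
b = np.array([1.0, 2.0])
a_unit = a / np.linalg.norm(a)  # [0.6, 0.8]
b_unit = b / np.linalg.norm(b)

# On unit-length vectors, cosine similarity is just a dot product.
cosine = float(np.dot(a_unit, b_unit))
reference = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
assert abs(cosine - reference) < 1e-9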
        if embedding:
            redis_client.expire(embedding_cache_key, 600)
            return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
        try:
            embedding_result = self._model_instance.invoke_text_embedding(
                texts=[text],
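The embed_query fragment above reads single-query vectors from Redis rather than from the embeddings table. A minimal sketch of that read/write pattern follows; only the base64 + np.frombuffer decode and the 600-second expire come from the hunk, while the write path (setex), the key argument, and the function names are assumptions for illustration:

import base64

import numpy as np

def cache_query_embedding(redis_client, key: str, vector: list[float], ttl: int = 600) -> None:
    # Store the vector as base64-encoded raw float64 bytes so the read path
    # can recover it with np.frombuffer(base64.b64decode(...), dtype="float").
    redis_client.setex(key, ttl, base64.b64encode(np.array(vector, dtype="float").tobytes()))

def load_query_embedding(redis_client, key: str) -> list[float] | None:
    cached = redis_client.get(key)
    if cached is None:
        return None
    redis_client.expire(key, 600)  # refresh the TTL on a cache hit, as embed_query does
    return list(np.frombuffer(base64.b64decode(cached), dtype="float"))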
beat_schedule = {
    'clean_embedding_cache_task': {
        'task': 'schedule.clean_embedding_cache_task.clean_embedding_cache_task',
-        'schedule': timedelta(days=7),
+        'schedule': timedelta(days=1),
    },
    'clean_unused_datasets_task': {
        'task': 'schedule.clean_unused_datasets_task.clean_unused_datasets_task',
-        'schedule': timedelta(minutes=3),
+        'schedule': timedelta(days=1),
    }
}
celery_app.conf.update(
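The beat schedule above wires clean_embedding_cache_task to a daily interval. The task body itself is not part of this diff; a plausible sketch of such a cleanup is shown below, where the celery decorator, the retention window, and the exact filter are assumptions:

import datetime

from celery import shared_task

from extensions.ext_database import db
from models.dataset import Embedding

CLEAN_DAYS = 7  # assumed retention window, not specified in this diff

@shared_task
def clean_embedding_cache_task():
    # Drop cached document embeddings older than the retention window so the
    # embeddings table does not grow without bound.
    cutoff = datetime.datetime.utcnow() - datetime.timedelta(days=CLEAN_DAYS)
    db.session.query(Embedding).filter(Embedding.created_at < cutoff).delete(synchronize_session=False)
    db.session.commit()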
| """add-embeddings-provider-name | |||||
| Revision ID: a8d7385a7b66 | |||||
| Revises: 17b5ab037c40 | |||||
| Create Date: 2024-04-02 12:17:22.641525 | |||||
| """ | |||||
| import sqlalchemy as sa | |||||
| from alembic import op | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'a8d7385a7b66' | |||||
| down_revision = '17b5ab037c40' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('embeddings', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('provider_name', sa.String(length=40), server_default=sa.text("''::character varying"), nullable=False)) | |||||
| batch_op.drop_constraint('embedding_hash_idx', type_='unique') | |||||
| batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash', 'provider_name']) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('embeddings', schema=None) as batch_op: | |||||
| batch_op.drop_constraint('embedding_hash_idx', type_='unique') | |||||
| batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash']) | |||||
| batch_op.drop_column('provider_name') | |||||
| # ### end Alembic commands ### |
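The migration widens the embedding_hash_idx unique constraint to (model_name, hash, provider_name), which is what the IntegrityError/rollback handling in embed_documents relies on when two workers cache the same text concurrently. A hypothetical helper illustrating that behaviour (the function itself is not part of the change; only the model, hash helper, and constraint come from the diff):

from sqlalchemy.exc import IntegrityError

from extensions.ext_database import db
from libs import helper
from models.dataset import Embedding

def store_cached_embedding(model_name: str, provider_name: str, text: str, vector: list[float]) -> None:
    # The composite unique constraint rejects a second row for the same
    # (model_name, hash, provider_name), so a duplicate insert is resolved
    # by rolling back and keeping the row that won the race.
    row = Embedding(model_name=model_name,
                    provider_name=provider_name,
                    hash=helper.generate_text_hash(text))
    row.set_embedding(vector)
    db.session.add(row)
    try:
        db.session.commit()
    except IntegrityError:
        db.session.rollback()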
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f'Vector_index_{normalized_dataset_id}_Node'

class DatasetProcessRule(db.Model):
    __tablename__ = 'dataset_process_rules'
    __table_args__ = (

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(UUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
-    data_source_type = db.Column(db.String(255), nullable=False, server_default=db.text("'database'::character varying"))
+    data_source_type = db.Column(db.String(255), nullable=False,
+                                 server_default=db.text("'database'::character varying"))

    @property
    def keyword_table_dict(self):

                if isinstance(node_idxs, list):
                    dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(
            id=self.dataset_id
    __tablename__ = 'embeddings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='embedding_pkey'),
-        db.UniqueConstraint('model_name', 'hash', name='embedding_hash_idx')
+        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
    )

    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    provider_name = db.Column(db.String(40), nullable=False,
+                              server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
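set_embedding pickles the vector into the embedding LargeBinary column, and the cache read path in embed_documents calls a matching get_embedding accessor, presumably the pickle.loads counterpart. A self-contained round-trip sketch using a stand-in class (the stub below is illustrative, not the actual model):

import pickle

class EmbeddingStub:
    # Stand-in with the same (de)serialization as the Embedding model's
    # set_embedding / get_embedding pair.
    embedding: bytes

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)

stub = EmbeddingStub()
stub.set_embedding([0.1, 0.2, 0.3])
assert stub.get_embedding() == [0.1, 0.2, 0.3]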