Co-authored-by: jyong <jyong@dify.ai>
 SSRF_PROXY_HTTPS_URL=
 BATCH_UPLOAD_LIMIT=10
+KEYWORD_DATA_SOURCE_TYPE=database
     'KEYWORD_STORE': 'jieba',
     'BATCH_UPLOAD_LIMIT': 20,
     'TOOL_ICON_CACHE_MAX_AGE': 3600,
+    'KEYWORD_DATA_SOURCE_TYPE': 'database',
 }

         self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
         self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')
+        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')

 class CloudEditionConfig(Config):
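This change introduces a KEYWORD_DATA_SOURCE_TYPE setting: the env example and the DEFAULTS entry both default it to 'database', and the Config object picks it up through get_env like its neighbours. A minimal sketch of that lookup, using a simplified stand-in for the get_env helper (the real one lives in Dify's config module alongside the DEFAULTS mapping shown above):

    import os

    # Simplified stand-in for the get_env helper used above; the real helper
    # consults the full DEFAULTS mapping in Dify's config module.
    DEFAULTS = {'KEYWORD_DATA_SOURCE_TYPE': 'database'}

    def get_env(key: str) -> str:
        return os.environ.get(key, DEFAULTS.get(key))

    # Unset -> 'database': keyword tables keep living in the keyword_table column.
    # Any other value makes newly created tables go to file storage (see the
    # jieba handler changes below).
    print(get_env('KEYWORD_DATA_SOURCE_TYPE'))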
 from collections import defaultdict
 from typing import Any, Optional

+from flask import current_app
 from pydantic import BaseModel

 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
 from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment
             if dataset_keyword_table:
                 db.session.delete(dataset_keyword_table)
                 db.session.commit()
+                if dataset_keyword_table.data_source_type != 'database':
+                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+                    storage.delete(file_key)
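All three code paths touched below (delete, save, load) address the stored table by the same key, derived from the tenant and dataset ids. Purely illustrative, with made-up ids:

    # Made-up ids, only to illustrate the key layout reused by delete, save and load.
    tenant_id = 'tenant-uuid'    # self.dataset.tenant_id in the handler
    dataset_id = 'dataset-uuid'  # self.dataset.id in the handler

    file_key = 'keyword_files/' + tenant_id + '/' + dataset_id + '.txt'
    # -> 'keyword_files/tenant-uuid/dataset-uuid.txt'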
     def _save_dataset_keyword_table(self, keyword_table):
         keyword_table_dict = {

                 "table": keyword_table
             }
         }
-        self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
-        db.session.commit()
+        dataset_keyword_table = self.dataset.dataset_keyword_table
+        keyword_data_source_type = dataset_keyword_table.data_source_type
+        if keyword_data_source_type == 'database':
+            dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
+            db.session.commit()
+        else:
+            file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
+            if storage.exists(file_key):
+                storage.delete(file_key)
+            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
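Both branches serialise the table with cls=SetEncoder, which this diff references but does not show. A rough sketch, assuming it is the usual set-aware json.JSONEncoder defined alongside this handler (the name comes from the calls above; its placement is an assumption):

    import json

    class SetEncoder(json.JSONEncoder):
        """Sets are not JSON-serialisable, so write them out as plain lists."""
        def default(self, obj):
            if isinstance(obj, set):
                return list(obj)
            return super().default(obj)

    # A tiny keyword table whose values are sets of segment ids.
    print(json.dumps({'keyword': {'seg-1', 'seg-2'}}, cls=SetEncoder))

The matching list-to-set conversion happens on the read side, in the keyword_table_dict property further down.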
     def _get_dataset_keyword_table(self) -> Optional[dict]:
         lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
         with redis_client.lock(lock_name, timeout=20):
             dataset_keyword_table = self.dataset.dataset_keyword_table
             if dataset_keyword_table:
-                if dataset_keyword_table.keyword_table_dict:
-                    return dataset_keyword_table.keyword_table_dict['__data__']['table']
+                keyword_table_dict = dataset_keyword_table.keyword_table_dict
+                if keyword_table_dict:
+                    return keyword_table_dict['__data__']['table']
             else:
+                keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
                 dataset_keyword_table = DatasetKeywordTable(
                     dataset_id=self.dataset.id,
-                    keyword_table=json.dumps({
+                    keyword_table='',
+                    data_source_type=keyword_data_source_type,
+                )
+                if keyword_data_source_type == 'database':
+                    dataset_keyword_table.keyword_table = json.dumps({
                         '__type__': 'keyword_table',
                         '__data__': {
                             "index_id": self.dataset.id,
                             "table": {}
                         }
                     }, cls=SetEncoder)
-                )
                 db.session.add(dataset_keyword_table)
                 db.session.commit()
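For orientation, the payload both storage modes hold is the same envelope built above: a keyword-to-segment-ids mapping under '__data__' / 'table'. An illustrative value (keywords and ids invented), before serialisation turns the sets into lists:

    # Illustrative only; real keys are jieba-extracted terms and real values are
    # sets of DocumentSegment ids belonging to the dataset.
    keyword_table_dict = {
        '__type__': 'keyword_table',
        '__data__': {
            'index_id': 'dataset-uuid',
            'table': {
                '工作流': {'segment-uuid-1', 'segment-uuid-3'},
                'workflow': {'segment-uuid-2'},
            },
        },
    }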
             return os.path.exists(filename)

+    def delete(self, filename):
+        if self.storage_type == 's3':
+            self.client.delete_object(Bucket=self.bucket_name, Key=filename)
+        elif self.storage_type == 'azure-blob':
+            blob_container = self.client.get_container_client(container=self.bucket_name)
+            blob_container.delete_blob(filename)
+        else:
+            if not self.folder or self.folder.endswith('/'):
+                filename = self.folder + filename
+            else:
+                filename = self.folder + '/' + filename
+            if os.path.exists(filename):
+                os.remove(filename)

 storage = Storage()
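The new Storage.delete() completes the small exists/save/load_once/delete surface the keyword handler relies on. A hedged usage sketch of the rewrite pattern used in _save_dataset_keyword_table above (the helper name here is invented for illustration):

    from extensions.ext_storage import storage

    def replace_keyword_file(file_key: str, payload: bytes) -> None:
        # Drop any stale copy first, then write the freshly serialised table;
        # mirrors the file-storage branch of _save_dataset_keyword_table.
        if storage.exists(file_key):
            storage.delete(file_key)
        storage.save(file_key, payload)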
| """add-keyworg-table-storage-type | |||||
| Revision ID: 17b5ab037c40 | |||||
| Revises: a8f9b3c45e4a | |||||
| Create Date: 2024-04-01 09:48:54.232201 | |||||
| """ | |||||
| import sqlalchemy as sa | |||||
| from alembic import op | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '17b5ab037c40' | |||||
| down_revision = 'a8f9b3c45e4a' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: | |||||
| batch_op.add_column(sa.Column('data_source_type', sa.String(length=255), server_default=sa.text("'database'::character varying"), nullable=False)) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: | |||||
| batch_op.drop_column('data_source_type') | |||||
| # ### end Alembic commands ### |
 import json
+import logging
 import pickle
 from json import JSONDecodeError

 from sqlalchemy.dialects.postgresql import JSONB, UUID

 from extensions.ext_database import db
+from extensions.ext_storage import storage
 from models.account import Account
 from models.model import App, UploadFile

     id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
     dataset_id = db.Column(UUID, nullable=False, unique=True)
     keyword_table = db.Column(db.Text, nullable=False)
+    data_source_type = db.Column(db.String(255), nullable=False, server_default=db.text("'database'::character varying"))
     @property
     def keyword_table_dict(self):

                         if isinstance(node_idxs, list):
                             dct[keyword] = set(node_idxs)
                 return dct

-        return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
+        # get dataset
+        dataset = Dataset.query.filter_by(
+            id=self.dataset_id
+        ).first()
+        if not dataset:
+            return None
+        if self.data_source_type == 'database':
+            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
+        else:
+            file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt'
+            try:
+                keyword_table_text = storage.load_once(file_key)
+                if keyword_table_text:
+                    return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder)
+                return None
+            except Exception as e:
+                logging.exception(str(e))
+                return None

 class Embedding(db.Model):
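SetDecoder (partially visible in the context lines above) is the read-side counterpart of SetEncoder: object values that come back as lists are converted to sets again, whichever backend the JSON was read from. A self-contained round-trip sketch under that assumption (the class is re-declared here only for illustration, not imported from Dify):

    import json

    class SetDecoder(json.JSONDecoder):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, object_hook=self.object_hook, **kwargs)

        def object_hook(self, dct):
            # Mirror the fragment above: list values are assumed to have been sets.
            for keyword, node_idxs in dct.items():
                if isinstance(node_idxs, list):
                    dct[keyword] = set(node_idxs)
            return dct

    # What SetEncoder wrote out (sets flattened to lists)...
    stored = '{"keyword": ["seg-1", "seg-2"]}'
    # ...comes back with set values, whether it was read from the keyword_table
    # column or from the keyword_files/<tenant>/<dataset>.txt object.
    assert json.loads(stored, cls=SetDecoder) == {'keyword': {'seg-1', 'seg-2'}}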