Co-authored-by: Nicolas <nicolascamara29@gmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: takatost <takatost@gmail.com>
tags/0.6.11
| @@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200 | |||
| WORKFLOW_CALL_MAX_DEPTH=5 | |||
| # App configuration | |||
| APP_MAX_EXECUTION_TIME=1200 | |||
| @@ -29,13 +29,13 @@ from .app import ( | |||
| ) | |||
| # Import auth controllers | |||
| from .auth import activate, data_source_oauth, login, oauth | |||
| from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth | |||
| # Import billing controllers | |||
| from .billing import billing | |||
| # Import datasets controllers | |||
| from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing | |||
| from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website | |||
| # Import explore controllers | |||
| from .explore import ( | |||
| @@ -0,0 +1,67 @@ | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, reqparse | |||
| from werkzeug.exceptions import Forbidden | |||
| from controllers.console import api | |||
| from controllers.console.auth.error import ApiKeyAuthFailedError | |||
| from libs.login import login_required | |||
| from services.auth.api_key_auth_service import ApiKeyAuthService | |||
| from ..setup import setup_required | |||
| from ..wraps import account_initialization_required | |||
| class ApiKeyAuthDataSource(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| # The current user's role must be admin or owner | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) | |||
| if data_source_api_key_bindings: | |||
| return { | |||
| 'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in | |||
| data_source_api_key_bindings]} | |||
| return {'settings': []} | |||
| class ApiKeyAuthDataSourceBinding(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| # The current user's role must be admin or owner | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('category', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('provider', type=str, required=True, nullable=False, location='json') | |||
| parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') | |||
| args = parser.parse_args() | |||
| ApiKeyAuthService.validate_api_key_auth_args(args) | |||
| try: | |||
| ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) | |||
| except Exception as e: | |||
| raise ApiKeyAuthFailedError(str(e)) | |||
| return {'result': 'success'}, 200 | |||
| class ApiKeyAuthDataSourceBindingDelete(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, binding_id): | |||
| # The current user's role must be admin or owner | |||
| if not current_user.is_admin_or_owner: | |||
| raise Forbidden() | |||
| ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) | |||
| return {'result': 'success'}, 200 | |||
| api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') | |||
| api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') | |||
| api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>') | |||
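As a quick check of the three routes registered above, a console client could exercise them roughly as in the sketch below. This is illustrative only: the console API prefix, host and auth header are assumptions about the deployment, and the Firecrawl key is a placeholder.

import requests

# Assumed console API base URL and an already-authenticated session; adjust for your setup.
BASE = "http://localhost:5001/console/api"
session = requests.Session()
session.headers["Authorization"] = "Bearer <console-access-token>"  # placeholder token

# List existing data-source API-key bindings (ApiKeyAuthDataSource.get).
print(session.get(f"{BASE}/api-key-auth/data-source").json())

# Create a Firecrawl binding (ApiKeyAuthDataSourceBinding.post).
payload = {
    "category": "website",
    "provider": "firecrawl",
    "credentials": {
        "auth_type": "bearer",
        "config": {"api_key": "fc-...", "base_url": "https://api.firecrawl.dev"},
    },
}
print(session.post(f"{BASE}/api-key-auth/data-source/binding", json=payload).json())

# Remove a binding by id (ApiKeyAuthDataSourceBindingDelete.delete).
# session.delete(f"{BASE}/api-key-auth/data-source/<binding-uuid>")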
| @@ -0,0 +1,7 @@ | |||
| from libs.exception import BaseHTTPException | |||
| class ApiKeyAuthFailedError(BaseHTTPException): | |||
| error_code = 'auth_failed' | |||
| description = "{message}" | |||
| code = 500 | |||
| @@ -16,7 +16,7 @@ from extensions.ext_database import db | |||
| from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields | |||
| from libs.login import login_required | |||
| from models.dataset import Document | |||
| from models.source import DataSourceBinding | |||
| from models.source import DataSourceOauthBinding | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from tasks.document_indexing_sync_task import document_indexing_sync_task | |||
| @@ -29,9 +29,9 @@ class DataSourceApi(Resource): | |||
| @marshal_with(integrate_list_fields) | |||
| def get(self): | |||
| # get workspace data source integrates | |||
| data_source_integrates = db.session.query(DataSourceBinding).filter( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.disabled == False | |||
| data_source_integrates = db.session.query(DataSourceOauthBinding).filter( | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.disabled == False | |||
| ).all() | |||
| base_url = request.url_root.rstrip('/') | |||
| @@ -71,7 +71,7 @@ class DataSourceApi(Resource): | |||
| def patch(self, binding_id, action): | |||
| binding_id = str(binding_id) | |||
| action = str(action) | |||
| data_source_binding = DataSourceBinding.query.filter_by( | |||
| data_source_binding = DataSourceOauthBinding.query.filter_by( | |||
| id=binding_id | |||
| ).first() | |||
| if data_source_binding is None: | |||
| @@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource): | |||
| data_source_info = json.loads(document.data_source_info) | |||
| exist_page_ids.append(data_source_info['notion_page_id']) | |||
| # get all authorized pages | |||
| data_source_bindings = DataSourceBinding.query.filter_by( | |||
| data_source_bindings = DataSourceOauthBinding.query.filter_by( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider='notion', | |||
| disabled=False | |||
| @@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource): | |||
| def get(self, workspace_id, page_id, page_type): | |||
| workspace_id = str(workspace_id) | |||
| page_id = str(page_id) | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| @@ -315,6 +315,22 @@ class DatasetIndexingEstimateApi(Resource): | |||
| document_model=args['doc_form'] | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif args['info_list']['data_source_type'] == 'website_crawl': | |||
| website_info_list = args['info_list']['website_info_list'] | |||
| for url in website_info_list['urls']: | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| website_info={ | |||
| "provider": website_info_list['provider'], | |||
| "job_id": website_info_list['job_id'], | |||
| "url": url, | |||
| "tenant_id": current_user.current_tenant_id, | |||
| "mode": 'crawl', | |||
| "only_main_content": website_info_list['only_main_content'] | |||
| }, | |||
| document_model=args['doc_form'] | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| else: | |||
| raise ValueError('Data source type not supported') | |||
| indexing_runner = IndexingRunner() | |||
| @@ -519,6 +535,7 @@ class DatasetRetrievalSettingMockApi(Resource): | |||
| raise ValueError(f"Unsupported vector db type {vector_type}.") | |||
| class DatasetErrorDocs(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -465,6 +465,20 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| document_model=document.doc_form | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| elif document.data_source_type == 'website_crawl': | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| website_info={ | |||
| "provider": data_source_info['provider'], | |||
| "job_id": data_source_info['job_id'], | |||
| "url": data_source_info['url'], | |||
| "tenant_id": current_user.current_tenant_id, | |||
| "mode": data_source_info['mode'], | |||
| "only_main_content": data_source_info['only_main_content'] | |||
| }, | |||
| document_model=document.doc_form | |||
| ) | |||
| extract_settings.append(extract_setting) | |||
| else: | |||
| raise ValueError('Data source type not supported') | |||
| @@ -952,6 +966,33 @@ class DocumentRenameApi(DocumentResource): | |||
| return document | |||
| class WebsiteDocumentSyncApi(DocumentResource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, dataset_id, document_id): | |||
| """sync website document.""" | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound('Dataset not found.') | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset.id, document_id) | |||
| if not document: | |||
| raise NotFound('Document not found.') | |||
| if document.tenant_id != current_user.current_tenant_id: | |||
| raise Forbidden('No permission.') | |||
| if document.data_source_type != 'website_crawl': | |||
| raise ValueError('Document is not a website document.') | |||
| # 403 if document is archived | |||
| if DocumentService.check_archived(document): | |||
| raise ArchivedDocumentImmutableError() | |||
| # sync document | |||
| DocumentService.sync_website_document(dataset_id, document) | |||
| return {'result': 'success'}, 200 | |||
| api.add_resource(GetProcessRuleApi, '/datasets/process-rule') | |||
| api.add_resource(DatasetDocumentListApi, | |||
| '/datasets/<uuid:dataset_id>/documents') | |||
| @@ -980,3 +1021,5 @@ api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uui | |||
| api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry') | |||
| api.add_resource(DocumentRenameApi, | |||
| '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename') | |||
| api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync') | |||
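Calling the new website-sync route from a client could look like the hedged sketch below; the base URL, auth header and UUIDs are placeholders, not part of this diff.

import requests

BASE = "http://localhost:5001/console/api"  # assumed console prefix
headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder auth

# WebsiteDocumentSyncApi.get flips the document back to 'waiting' and re-runs crawling/indexing.
resp = requests.get(
    f"{BASE}/datasets/<dataset-uuid>/documents/<document-uuid>/website-sync",
    headers=headers,
)
print(resp.json())  # expected: {'result': 'success'}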
| @@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException): | |||
| code = 400 | |||
| class WebsiteCrawlError(BaseHTTPException): | |||
| error_code = 'crawl_failed' | |||
| description = "{message}" | |||
| code = 500 | |||
| class DatasetInUseError(BaseHTTPException): | |||
| error_code = 'dataset_in_use' | |||
| description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." | |||
| @@ -0,0 +1,49 @@ | |||
| from flask_restful import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.datasets.error import WebsiteCrawlError | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from libs.login import login_required | |||
| from services.website_service import WebsiteService | |||
| class WebsiteCrawlApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, choices=['firecrawl'], | |||
| required=True, nullable=True, location='json') | |||
| parser.add_argument('url', type=str, required=True, nullable=True, location='json') | |||
| parser.add_argument('options', type=dict, required=True, nullable=True, location='json') | |||
| args = parser.parse_args() | |||
| WebsiteService.document_create_args_validate(args) | |||
| # crawl url | |||
| try: | |||
| result = WebsiteService.crawl_url(args) | |||
| except Exception as e: | |||
| raise WebsiteCrawlError(str(e)) | |||
| return result, 200 | |||
| class WebsiteCrawlStatusApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, job_id: str): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') | |||
| args = parser.parse_args() | |||
| # get crawl status | |||
| try: | |||
| result = WebsiteService.get_crawl_status(job_id, args['provider']) | |||
| except Exception as e: | |||
| raise WebsiteCrawlError(str(e)) | |||
| return result, 200 | |||
| api.add_resource(WebsiteCrawlApi, '/website/crawl') | |||
| api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>') | |||
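An end-to-end sketch of the two crawl endpoints registered above: start a crawl, then poll its status with the returned job_id. The option names mirror what WebsiteService.crawl_url reads later in this diff; the host and auth header are assumptions.

import time

import requests

BASE = "http://localhost:5001/console/api"  # assumed console prefix
headers = {"Authorization": "Bearer <console-access-token>"}  # placeholder auth

# Start a crawl (WebsiteCrawlApi.post).
body = {
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {
        "crawl_sub_pages": True,
        "only_main_content": True,
        "limit": 5,
        "max_depth": 2,
    },
}
job = requests.post(f"{BASE}/website/crawl", json=body, headers=headers).json()
job_id = job["job_id"]

# Poll the crawl status (WebsiteCrawlStatusApi.get) until it completes.
while True:
    status = requests.get(
        f"{BASE}/website/crawl/status/{job_id}",
        params={"provider": "firecrawl"},
        headers=headers,
    ).json()
    if status["status"] == "completed":
        break
    time.sleep(5)
print(status["total"], "pages crawled")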
| @@ -339,7 +339,7 @@ class IndexingRunner: | |||
| def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \ | |||
| -> list[Document]: | |||
| # load file | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: | |||
| return [] | |||
| data_source_info = dataset_document.data_source_info_dict | |||
| @@ -375,6 +375,23 @@ class IndexingRunner: | |||
| document_model=dataset_document.doc_form | |||
| ) | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) | |||
| elif dataset_document.data_source_type == 'website_crawl': | |||
| if (not data_source_info or 'provider' not in data_source_info | |||
| or 'url' not in data_source_info or 'job_id' not in data_source_info): | |||
| raise ValueError("no website import info found") | |||
| extract_setting = ExtractSetting( | |||
| datasource_type="website_crawl", | |||
| website_info={ | |||
| "provider": data_source_info['provider'], | |||
| "job_id": data_source_info['job_id'], | |||
| "tenant_id": dataset_document.tenant_id, | |||
| "url": data_source_info['url'], | |||
| "mode": data_source_info['mode'], | |||
| "only_main_content": data_source_info['only_main_content'] | |||
| }, | |||
| document_model=dataset_document.doc_form | |||
| ) | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode']) | |||
| # update document status to splitting | |||
| self._update_document_index_status( | |||
| document_id=dataset_document.id, | |||
| @@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): | |||
| default=float(credentials.get('presence_penalty', 0)), | |||
| min=-2, | |||
| max=2 | |||
| ) | |||
| ), | |||
| ], | |||
| pricing=PriceConfig( | |||
| input=Decimal(cred_with_endpoint.get('input_price', 0)), | |||
| @@ -4,3 +4,4 @@ from enum import Enum | |||
| class DatasourceType(Enum): | |||
| FILE = "upload_file" | |||
| NOTION = "notion_import" | |||
| WEBSITE = "website_crawl" | |||
| @@ -1,3 +1,5 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel, ConfigDict | |||
| from models.dataset import Document | |||
| @@ -19,14 +21,33 @@ class NotionInfo(BaseModel): | |||
| super().__init__(**data) | |||
| class WebsiteInfo(BaseModel): | |||
| """ | |||
| website import info. | |||
| """ | |||
| provider: str | |||
| job_id: str | |||
| url: str | |||
| mode: str | |||
| tenant_id: str | |||
| only_main_content: bool = False | |||
| class Config: | |||
| arbitrary_types_allowed = True | |||
| def __init__(self, **data) -> None: | |||
| super().__init__(**data) | |||
| class ExtractSetting(BaseModel): | |||
| """ | |||
| Model class for provider response. | |||
| """ | |||
| datasource_type: str | |||
| upload_file: UploadFile = None | |||
| notion_info: NotionInfo = None | |||
| document_model: str = None | |||
| upload_file: Optional[UploadFile] = None | |||
| notion_info: Optional[NotionInfo] = None | |||
| website_info: Optional[WebsiteInfo] = None | |||
| document_model: Optional[str] = None | |||
| model_config = ConfigDict(arbitrary_types_allowed=True) | |||
| def __init__(self, **data) -> None: | |||
| @@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor | |||
| from core.rag.extractor.entity.datasource_type import DatasourceType | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.excel_extractor import ExcelExtractor | |||
| from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor | |||
| from core.rag.extractor.html_extractor import HtmlExtractor | |||
| from core.rag.extractor.markdown_extractor import MarkdownExtractor | |||
| from core.rag.extractor.notion_extractor import NotionExtractor | |||
| @@ -154,5 +155,17 @@ class ExtractProcessor: | |||
| tenant_id=extract_setting.notion_info.tenant_id, | |||
| ) | |||
| return extractor.extract() | |||
| elif extract_setting.datasource_type == DatasourceType.WEBSITE.value: | |||
| if extract_setting.website_info.provider == 'firecrawl': | |||
| extractor = FirecrawlWebExtractor( | |||
| url=extract_setting.website_info.url, | |||
| job_id=extract_setting.website_info.job_id, | |||
| tenant_id=extract_setting.website_info.tenant_id, | |||
| mode=extract_setting.website_info.mode, | |||
| only_main_content=extract_setting.website_info.only_main_content | |||
| ) | |||
| return extractor.extract() | |||
| else: | |||
| raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}") | |||
| else: | |||
| raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}") | |||
| @@ -0,0 +1,132 @@ | |||
| import json | |||
| import time | |||
| import requests | |||
| from extensions.ext_storage import storage | |||
| class FirecrawlApp: | |||
| def __init__(self, api_key=None, base_url=None): | |||
| self.api_key = api_key | |||
| self.base_url = base_url or 'https://api.firecrawl.dev' | |||
| if self.api_key is None and self.base_url == 'https://api.firecrawl.dev': | |||
| raise ValueError('No API key provided') | |||
| def scrape_url(self, url, params=None) -> dict: | |||
| headers = { | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f'Bearer {self.api_key}' | |||
| } | |||
| json_data = {'url': url} | |||
| if params: | |||
| json_data.update(params) | |||
| response = requests.post( | |||
| f'{self.base_url}/v0/scrape', | |||
| headers=headers, | |||
| json=json_data | |||
| ) | |||
| if response.status_code == 200: | |||
| response = response.json() | |||
| if response['success']: | |||
| data = response['data'] | |||
| return { | |||
| 'title': data.get('metadata').get('title'), | |||
| 'description': data.get('metadata').get('description'), | |||
| 'source_url': data.get('metadata').get('sourceURL'), | |||
| 'markdown': data.get('markdown') | |||
| } | |||
| else: | |||
| raise Exception(f'Failed to scrape URL. Error: {response["error"]}') | |||
| elif response.status_code in [402, 409, 500]: | |||
| error_message = response.json().get('error', 'Unknown error occurred') | |||
| raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}') | |||
| else: | |||
| raise Exception(f'Failed to scrape URL. Status code: {response.status_code}') | |||
| def crawl_url(self, url, params=None) -> str: | |||
| start_time = time.time() | |||
| headers = self._prepare_headers() | |||
| json_data = {'url': url} | |||
| if params: | |||
| json_data.update(params) | |||
| response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers) | |||
| if response.status_code == 200: | |||
| job_id = response.json().get('jobId') | |||
| return job_id | |||
| else: | |||
| self._handle_error(response, 'start crawl job') | |||
| def check_crawl_status(self, job_id) -> dict: | |||
| headers = self._prepare_headers() | |||
| response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers) | |||
| if response.status_code == 200: | |||
| crawl_status_response = response.json() | |||
| if crawl_status_response.get('status') == 'completed': | |||
| total = crawl_status_response.get('total', 0) | |||
| if total == 0: | |||
| raise Exception('Failed to check crawl status. Error: No page found') | |||
| data = crawl_status_response.get('data', []) | |||
| url_data_list = [] | |||
| for item in data: | |||
| if isinstance(item, dict) and 'metadata' in item and 'markdown' in item: | |||
| url_data = { | |||
| 'title': item.get('metadata').get('title'), | |||
| 'description': item.get('metadata').get('description'), | |||
| 'source_url': item.get('metadata').get('sourceURL'), | |||
| 'markdown': item.get('markdown') | |||
| } | |||
| url_data_list.append(url_data) | |||
| if url_data_list: | |||
| file_key = 'website_files/' + job_id + '.txt' | |||
| if storage.exists(file_key): | |||
| storage.delete(file_key) | |||
| storage.save(file_key, json.dumps(url_data_list).encode('utf-8')) | |||
| return { | |||
| 'status': 'completed', | |||
| 'total': crawl_status_response.get('total'), | |||
| 'current': crawl_status_response.get('current'), | |||
| 'data': url_data_list | |||
| } | |||
| else: | |||
| return { | |||
| 'status': crawl_status_response.get('status'), | |||
| 'total': crawl_status_response.get('total'), | |||
| 'current': crawl_status_response.get('current'), | |||
| 'data': [] | |||
| } | |||
| else: | |||
| self._handle_error(response, 'check crawl status') | |||
| def _prepare_headers(self): | |||
| return { | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f'Bearer {self.api_key}' | |||
| } | |||
| def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5): | |||
| for attempt in range(retries): | |||
| response = requests.post(url, headers=headers, json=data) | |||
| if response.status_code == 502: | |||
| time.sleep(backoff_factor * (2 ** attempt)) | |||
| else: | |||
| return response | |||
| return response | |||
| def _get_request(self, url, headers, retries=3, backoff_factor=0.5): | |||
| for attempt in range(retries): | |||
| response = requests.get(url, headers=headers) | |||
| if response.status_code == 502: | |||
| time.sleep(backoff_factor * (2 ** attempt)) | |||
| else: | |||
| return response | |||
| return response | |||
| def _handle_error(self, response, action): | |||
| error_message = response.json().get('error', 'Unknown error occurred') | |||
| raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}') | |||
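The FirecrawlApp client above can also be exercised on its own, outside the extract pipeline. A minimal sketch assuming a valid Firecrawl API key; note that check_crawl_status caches completed results via the storage extension, so it expects the Flask app to be initialised when run inside the API process.

import time

from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp

firecrawl = FirecrawlApp(api_key="fc-...", base_url="https://api.firecrawl.dev")  # placeholder key

# Single-page scrape: returns title/description/source_url/markdown.
page = firecrawl.scrape_url("https://example.com", params={"pageOptions": {"onlyMainContent": True}})
print(page["title"], len(page["markdown"] or ""))

# Multi-page crawl: returns a Firecrawl job id to poll.
job_id = firecrawl.crawl_url("https://example.com", params={"crawlerOptions": {"limit": 3}})
while True:
    status = firecrawl.check_crawl_status(job_id)
    if status["status"] == "completed":
        break
    time.sleep(5)
print(status["total"], "pages,", len(status["data"]), "with content")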
| @@ -0,0 +1,60 @@ | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| from services.website_service import WebsiteService | |||
| class FirecrawlWebExtractor(BaseExtractor): | |||
| """ | |||
| Crawl and scrape websites and return content in clean, LLM-ready markdown. | |||
| Args: | |||
| url: The URL to scrape or crawl. | |||
| job_id: The Firecrawl crawl job id whose cached results should be loaded. | |||
| tenant_id: The tenant whose Firecrawl credentials are used. | |||
| mode: The mode of operation. Defaults to 'crawl'. Options are 'crawl' and 'scrape'. | |||
| only_main_content: Whether to return only the main page content, excluding headers, navs and footers. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| url: str, | |||
| job_id: str, | |||
| tenant_id: str, | |||
| mode: str = 'crawl', | |||
| only_main_content: bool = False | |||
| ): | |||
| """Initialize with url, api_key, base_url and mode.""" | |||
| self._url = url | |||
| self.job_id = job_id | |||
| self.tenant_id = tenant_id | |||
| self.mode = mode | |||
| self.only_main_content = only_main_content | |||
| def extract(self) -> list[Document]: | |||
| """Extract content from the URL.""" | |||
| documents = [] | |||
| if self.mode == 'crawl': | |||
| crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id) | |||
| if crawl_data is None: | |||
| return [] | |||
| document = Document(page_content=crawl_data.get('markdown', ''), | |||
| metadata={ | |||
| 'source_url': crawl_data.get('source_url'), | |||
| 'description': crawl_data.get('description'), | |||
| 'title': crawl_data.get('title') | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| elif self.mode == 'scrape': | |||
| scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id, | |||
| self.only_main_content) | |||
| document = Document(page_content=scrape_data.get('markdown', ''), | |||
| metadata={ | |||
| 'source_url': scrape_data.get('source_url'), | |||
| 'description': scrape_data.get('description'), | |||
| 'title': scrape_data.get('title') | |||
| } | |||
| ) | |||
| documents.append(document) | |||
| return documents | |||
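FirecrawlWebExtractor is normally constructed by ExtractProcessor, but it can be driven directly, as in the hedged sketch below; it assumes an active Flask app context, a bound Firecrawl key for the tenant, and a crawl job that has already completed, with the IDs as placeholders.

from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor

extractor = FirecrawlWebExtractor(
    url="https://example.com",
    job_id="<firecrawl-job-id>",   # placeholder
    tenant_id="<tenant-uuid>",     # placeholder
    mode="crawl",
    only_main_content=True,
)

# Yields core.rag.models.document.Document objects with markdown page_content
# and source_url/description/title metadata.
for doc in extractor.extract():
    print(doc.metadata["source_url"], len(doc.page_content))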
| @@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Document as DocumentModel | |||
| from models.source import DataSourceBinding | |||
| from models.source import DataSourceOauthBinding | |||
| logger = logging.getLogger(__name__) | |||
| @@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor): | |||
| @classmethod | |||
| def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' | |||
| DataSourceOauthBinding.tenant_id == tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"' | |||
| ) | |||
| ).first() | |||
| @@ -0,0 +1,64 @@ | |||
| # [REVIEW] Implement if needed? Do we need a new type of data source? | |||
| from abc import abstractmethod | |||
| import requests | |||
| from models.source import DataSourceBearerBinding | |||
| from flask_login import current_user | |||
| from extensions.ext_database import db | |||
| class BearerDataSource: | |||
| def __init__(self, api_key: str, api_base_url: str): | |||
| self.api_key = api_key | |||
| self.api_base_url = api_base_url | |||
| @abstractmethod | |||
| def validate_bearer_data_source(self): | |||
| """ | |||
| Validate the data source | |||
| """ | |||
| class FireCrawlDataSource(BearerDataSource): | |||
| def validate_bearer_data_source(self): | |||
| TEST_CRAWL_SITE_URL = "https://www.google.com" | |||
| FIRECRAWL_API_VERSION = "v0" | |||
| test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape" | |||
| headers = { | |||
| "Authorization": f"Bearer {self.api_key}", | |||
| "Content-Type": "application/json", | |||
| } | |||
| data = { | |||
| "url": TEST_CRAWL_SITE_URL, | |||
| } | |||
| # The scrape endpoint expects a POST request with a JSON body. | |||
| response = requests.post(test_api_endpoint, headers=headers, json=data) | |||
| return response.status_code == 200 and response.json().get("success") is True | |||
| def save_credentials(self): | |||
| # save data source binding | |||
| data_source_binding = DataSourceBearerBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBearerBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBearerBinding.provider == 'firecrawl', | |||
| DataSourceBearerBinding.endpoint_url == self.api_base_url, | |||
| DataSourceBearerBinding.bearer_key == self.api_key | |||
| ) | |||
| ).first() | |||
| if data_source_binding: | |||
| data_source_binding.disabled = False | |||
| db.session.commit() | |||
| else: | |||
| new_data_source_binding = DataSourceBearerBinding( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider='firecrawl', | |||
| endpoint_url=self.api_base_url, | |||
| bearer_key=self.api_key | |||
| ) | |||
| db.session.add(new_data_source_binding) | |||
| db.session.commit() | |||
| @@ -4,7 +4,7 @@ import requests | |||
| from flask_login import current_user | |||
| from extensions.ext_database import db | |||
| from models.source import DataSourceBinding | |||
| from models.source import DataSourceOauthBinding | |||
| class OAuthDataSource: | |||
| @@ -63,11 +63,11 @@ class NotionOAuth(OAuthDataSource): | |||
| 'total': len(pages) | |||
| } | |||
| # save data source binding | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.access_token == access_token | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.access_token == access_token | |||
| ) | |||
| ).first() | |||
| if data_source_binding: | |||
| @@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource): | |||
| data_source_binding.disabled = False | |||
| db.session.commit() | |||
| else: | |||
| new_data_source_binding = DataSourceBinding( | |||
| new_data_source_binding = DataSourceOauthBinding( | |||
| tenant_id=current_user.current_tenant_id, | |||
| access_token=access_token, | |||
| source_info=source_info, | |||
| @@ -98,11 +98,11 @@ class NotionOAuth(OAuthDataSource): | |||
| 'total': len(pages) | |||
| } | |||
| # save data source binding | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.access_token == access_token | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.access_token == access_token | |||
| ) | |||
| ).first() | |||
| if data_source_binding: | |||
| @@ -110,7 +110,7 @@ class NotionOAuth(OAuthDataSource): | |||
| data_source_binding.disabled = False | |||
| db.session.commit() | |||
| else: | |||
| new_data_source_binding = DataSourceBinding( | |||
| new_data_source_binding = DataSourceOauthBinding( | |||
| tenant_id=current_user.current_tenant_id, | |||
| access_token=access_token, | |||
| source_info=source_info, | |||
| @@ -121,12 +121,12 @@ class NotionOAuth(OAuthDataSource): | |||
| def sync_data_source(self, binding_id: str): | |||
| # save data source binding | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.id == binding_id, | |||
| DataSourceBinding.disabled == False | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.id == binding_id, | |||
| DataSourceOauthBinding.disabled == False | |||
| ) | |||
| ).first() | |||
| if data_source_binding: | |||
| @@ -0,0 +1,67 @@ | |||
| """add-api-key-auth-binding | |||
| Revision ID: 7b45942e39bb | |||
| Revises: 47cc7df8c4f3 | |||
| Create Date: 2024-05-14 07:31:29.702766 | |||
| """ | |||
| import sqlalchemy as sa | |||
| from alembic import op | |||
| import models as models | |||
| # revision identifiers, used by Alembic. | |||
| revision = '7b45942e39bb' | |||
| down_revision = '4e99a8df00ff' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('data_source_api_key_auth_bindings', | |||
| sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.StringUUID(), nullable=False), | |||
| sa.Column('category', sa.String(length=255), nullable=False), | |||
| sa.Column('provider', sa.String(length=255), nullable=False), | |||
| sa.Column('credentials', sa.Text(), nullable=True), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True), | |||
| sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey') | |||
| ) | |||
| with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op: | |||
| batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False) | |||
| batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False) | |||
| with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: | |||
| batch_op.drop_index('source_binding_tenant_id_idx') | |||
| batch_op.drop_index('source_info_idx') | |||
| op.rename_table('data_source_bindings', 'data_source_oauth_bindings') | |||
| with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op: | |||
| batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False) | |||
| batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin') | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op: | |||
| batch_op.drop_index('source_info_idx', postgresql_using='gin') | |||
| batch_op.drop_index('source_binding_tenant_id_idx') | |||
| op.rename_table('data_source_oauth_bindings', 'data_source_bindings') | |||
| with op.batch_alter_table('data_source_bindings', schema=None) as batch_op: | |||
| batch_op.create_index('source_info_idx', ['source_info'], unique=False) | |||
| batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False) | |||
| with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op: | |||
| batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx') | |||
| batch_op.drop_index('data_source_api_key_auth_binding_provider_idx') | |||
| op.drop_table('data_source_api_key_auth_bindings') | |||
| # ### end Alembic commands ### | |||
| @@ -270,7 +270,7 @@ class Document(db.Model): | |||
| 255), nullable=False, server_default=db.text("'text_model'::character varying")) | |||
| doc_language = db.Column(db.String(255), nullable=True) | |||
| DATA_SOURCES = ['upload_file', 'notion_import'] | |||
| DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl'] | |||
| @property | |||
| def display_status(self): | |||
| @@ -322,7 +322,7 @@ class Document(db.Model): | |||
| 'created_at': file_detail.created_at.timestamp() | |||
| } | |||
| } | |||
| elif self.data_source_type == 'notion_import': | |||
| elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl': | |||
| return json.loads(self.data_source_info) | |||
| return {} | |||
| @@ -1,11 +1,13 @@ | |||
| import json | |||
| from sqlalchemy.dialects.postgresql import JSONB | |||
| from extensions.ext_database import db | |||
| from models import StringUUID | |||
| class DataSourceBinding(db.Model): | |||
| __tablename__ = 'data_source_bindings' | |||
| class DataSourceOauthBinding(db.Model): | |||
| __tablename__ = 'data_source_oauth_bindings' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='source_binding_pkey'), | |||
| db.Index('source_binding_tenant_id_idx', 'tenant_id'), | |||
| @@ -20,3 +22,33 @@ class DataSourceBinding(db.Model): | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) | |||
| class DataSourceApiKeyAuthBinding(db.Model): | |||
| __tablename__ = 'data_source_api_key_auth_bindings' | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'), | |||
| db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'), | |||
| db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()')) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| category = db.Column(db.String(255), nullable=False) | |||
| provider = db.Column(db.String(255), nullable=False) | |||
| credentials = db.Column(db.Text, nullable=True) # JSON | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) | |||
| disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) | |||
| def to_dict(self): | |||
| return { | |||
| 'id': self.id, | |||
| 'tenant_id': self.tenant_id, | |||
| 'category': self.category, | |||
| 'provider': self.provider, | |||
| 'credentials': json.loads(self.credentials), | |||
| 'created_at': self.created_at.timestamp(), | |||
| 'updated_at': self.updated_at.timestamp(), | |||
| 'disabled': self.disabled | |||
| } | |||
| @@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000" | |||
| CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194" | |||
| CODE_EXECUTION_API_KEY="dify-sandbox" | |||
| FIRECRAWL_API_KEY = "fc-" | |||
| [tool.poetry] | |||
| name = "dify-api" | |||
| @@ -0,0 +1,10 @@ | |||
| from abc import ABC, abstractmethod | |||
| class ApiKeyAuthBase(ABC): | |||
| def __init__(self, credentials: dict): | |||
| self.credentials = credentials | |||
| @abstractmethod | |||
| def validate_credentials(self): | |||
| raise NotImplementedError | |||
| @@ -0,0 +1,14 @@ | |||
| from services.auth.firecrawl import FirecrawlAuth | |||
| class ApiKeyAuthFactory: | |||
| def __init__(self, provider: str, credentials: dict): | |||
| if provider == 'firecrawl': | |||
| self.auth = FirecrawlAuth(credentials) | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| def validate_credentials(self): | |||
| return self.auth.validate_credentials() | |||
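The factory hard-codes the provider switch, so supporting another API-key-auth provider means adding an ApiKeyAuthBase subclass and another branch here. A hypothetical sketch (the 'acmecrawl' provider, its endpoint and response shape are invented purely for illustration):

import requests

from services.auth.api_key_auth_base import ApiKeyAuthBase


class AcmeCrawlAuth(ApiKeyAuthBase):
    """Hypothetical provider used only to illustrate the extension point."""

    def __init__(self, credentials: dict):
        super().__init__(credentials)
        self.api_key = credentials.get('config', {}).get('api_key')
        if not self.api_key:
            raise ValueError('No API key provided')

    def validate_credentials(self):
        # Invented health-check endpoint; a real provider would document its own.
        response = requests.get(
            'https://api.acmecrawl.example/v1/me',
            headers={'Authorization': f'Bearer {self.api_key}'},
        )
        return response.status_code == 200


# The factory branch would then grow accordingly:
# elif provider == 'acmecrawl':
#     self.auth = AcmeCrawlAuth(credentials)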
| @@ -0,0 +1,70 @@ | |||
| import json | |||
| from core.helper import encrypter | |||
| from extensions.ext_database import db | |||
| from models.source import DataSourceApiKeyAuthBinding | |||
| from services.auth.api_key_auth_factory import ApiKeyAuthFactory | |||
| class ApiKeyAuthService: | |||
| @staticmethod | |||
| def get_provider_auth_list(tenant_id: str) -> list: | |||
| data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.disabled.is_(False) | |||
| ).all() | |||
| return data_source_api_key_bindings | |||
| @staticmethod | |||
| def create_provider_auth(tenant_id: str, args: dict): | |||
| auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() | |||
| if auth_result: | |||
| # Encrypt the api key | |||
| api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) | |||
| args['credentials']['config']['api_key'] = api_key | |||
| data_source_api_key_binding = DataSourceApiKeyAuthBinding() | |||
| data_source_api_key_binding.tenant_id = tenant_id | |||
| data_source_api_key_binding.category = args['category'] | |||
| data_source_api_key_binding.provider = args['provider'] | |||
| data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) | |||
| db.session.add(data_source_api_key_binding) | |||
| db.session.commit() | |||
| @staticmethod | |||
| def get_auth_credentials(tenant_id: str, category: str, provider: str): | |||
| data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.category == category, | |||
| DataSourceApiKeyAuthBinding.provider == provider, | |||
| DataSourceApiKeyAuthBinding.disabled.is_(False) | |||
| ).first() | |||
| if not data_source_api_key_bindings: | |||
| return None | |||
| credentials = json.loads(data_source_api_key_bindings.credentials) | |||
| return credentials | |||
| @staticmethod | |||
| def delete_provider_auth(tenant_id: str, binding_id: str): | |||
| data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( | |||
| DataSourceApiKeyAuthBinding.tenant_id == tenant_id, | |||
| DataSourceApiKeyAuthBinding.id == binding_id | |||
| ).first() | |||
| if data_source_api_key_binding: | |||
| db.session.delete(data_source_api_key_binding) | |||
| db.session.commit() | |||
| @classmethod | |||
| def validate_api_key_auth_args(cls, args): | |||
| if 'category' not in args or not args['category']: | |||
| raise ValueError('category is required') | |||
| if 'provider' not in args or not args['provider']: | |||
| raise ValueError('provider is required') | |||
| if 'credentials' not in args or not args['credentials']: | |||
| raise ValueError('credentials is required') | |||
| if not isinstance(args['credentials'], dict): | |||
| raise ValueError('credentials must be a dictionary') | |||
| if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: | |||
| raise ValueError('auth_type is required') | |||
| @@ -0,0 +1,56 @@ | |||
| import json | |||
| import requests | |||
| from services.auth.api_key_auth_base import ApiKeyAuthBase | |||
| class FirecrawlAuth(ApiKeyAuthBase): | |||
| def __init__(self, credentials: dict): | |||
| super().__init__(credentials) | |||
| auth_type = credentials.get('auth_type') | |||
| if auth_type != 'bearer': | |||
| raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') | |||
| self.api_key = credentials.get('config').get('api_key', None) | |||
| self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') | |||
| if not self.api_key: | |||
| raise ValueError('No API key provided') | |||
| def validate_credentials(self): | |||
| headers = self._prepare_headers() | |||
| options = { | |||
| 'url': 'https://example.com', | |||
| 'crawlerOptions': { | |||
| 'excludes': [], | |||
| 'includes': [], | |||
| 'limit': 1 | |||
| }, | |||
| 'pageOptions': { | |||
| 'onlyMainContent': True | |||
| } | |||
| } | |||
| response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) | |||
| if response.status_code == 200: | |||
| return True | |||
| else: | |||
| self._handle_error(response) | |||
| def _prepare_headers(self): | |||
| return { | |||
| 'Content-Type': 'application/json', | |||
| 'Authorization': f'Bearer {self.api_key}' | |||
| } | |||
| def _post_request(self, url, data, headers): | |||
| return requests.post(url, headers=headers, json=data) | |||
| def _handle_error(self, response): | |||
| if response.status_code in [402, 409, 500]: | |||
| error_message = response.json().get('error', 'Unknown error occurred') | |||
| raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') | |||
| else: | |||
| if response.text: | |||
| error_message = json.loads(response.text).get('error', 'Unknown error occurred') | |||
| raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') | |||
| raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') | |||
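Validating a credentials payload with FirecrawlAuth directly, the same shape the console controller receives, could look like the following sketch; the key is a placeholder and a real call hits Firecrawl's /v0/crawl endpoint.

from services.auth.firecrawl import FirecrawlAuth

credentials = {
    'auth_type': 'bearer',
    'config': {
        'api_key': 'fc-...',                      # placeholder key
        'base_url': 'https://api.firecrawl.dev',  # optional override
    },
}

auth = FirecrawlAuth(credentials)
if auth.validate_credentials():
    print('Firecrawl credentials accepted')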
| @@ -31,7 +31,7 @@ from models.dataset import ( | |||
| DocumentSegment, | |||
| ) | |||
| from models.model import UploadFile | |||
| from models.source import DataSourceBinding | |||
| from models.source import DataSourceOauthBinding | |||
| from services.errors.account import NoPermissionError | |||
| from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError | |||
| from services.errors.document import DocumentIndexingError | |||
| @@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task | |||
| from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task | |||
| from tasks.recover_document_indexing_task import recover_document_indexing_task | |||
| from tasks.retry_document_indexing_task import retry_document_indexing_task | |||
| from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task | |||
| class DatasetService: | |||
| @@ -508,18 +509,40 @@ class DocumentService: | |||
| @staticmethod | |||
| def retry_document(dataset_id: str, documents: list[Document]): | |||
| for document in documents: | |||
| # add retry flag | |||
| retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) | |||
| cache_result = redis_client.get(retry_indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise ValueError("Document is being retried, please try again later") | |||
| # retry document indexing | |||
| document.indexing_status = 'waiting' | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| # add retry flag | |||
| retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) | |||
| redis_client.setex(retry_indexing_cache_key, 600, 1) | |||
| # trigger async task | |||
| document_ids = [document.id for document in documents] | |||
| retry_document_indexing_task.delay(dataset_id, document_ids) | |||
| @staticmethod | |||
| def sync_website_document(dataset_id: str, document: Document): | |||
| # add sync flag | |||
| sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) | |||
| cache_result = redis_client.get(sync_indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise ValueError("Document is being synced, please try again later") | |||
| # sync document indexing | |||
| document.indexing_status = 'waiting' | |||
| data_source_info = document.data_source_info_dict | |||
| data_source_info['mode'] = 'scrape' | |||
| document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| redis_client.setex(sync_indexing_cache_key, 600, 1) | |||
| sync_website_document_indexing_task.delay(dataset_id, document.id) | |||
| @staticmethod | |||
| def get_documents_position(dataset_id): | |||
| document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() | |||
| if document: | |||
| @@ -545,6 +568,9 @@ class DocumentService: | |||
| notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] | |||
| for notion_info in notion_info_list: | |||
| count = count + len(notion_info['pages']) | |||
| elif document_data["data_source"]["type"] == "website_crawl": | |||
| website_info = document_data["data_source"]['info_list']['website_info_list'] | |||
| count = len(website_info['urls']) | |||
| batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| @@ -683,12 +709,12 @@ class DocumentService: | |||
| exist_document[data_source_info['notion_page_id']] = document.id | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| @@ -717,6 +743,28 @@ class DocumentService: | |||
| # delete not selected documents | |||
| if len(exist_document) > 0: | |||
| clean_notion_document_task.delay(list(exist_document.values()), dataset.id) | |||
| elif document_data["data_source"]["type"] == "website_crawl": | |||
| website_info = document_data["data_source"]['info_list']['website_info_list'] | |||
| urls = website_info['urls'] | |||
| for url in urls: | |||
| data_source_info = { | |||
| 'url': url, | |||
| 'provider': website_info['provider'], | |||
| 'job_id': website_info['job_id'], | |||
| 'only_main_content': website_info.get('only_main_content', False), | |||
| 'mode': 'crawl', | |||
| } | |||
| document = DocumentService.build_document(dataset, dataset_process_rule.id, | |||
| document_data["data_source"]["type"], | |||
| document_data["doc_form"], | |||
| document_data["doc_language"], | |||
| data_source_info, created_from, position, | |||
| account, url, batch) | |||
| db.session.add(document) | |||
| db.session.flush() | |||
| document_ids.append(document.id) | |||
| documents.append(document) | |||
| position += 1 | |||
| db.session.commit() | |||
| # trigger async task | |||
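For reference, the website_crawl branch above expects the document-creation args to carry a website_info_list; a representative shape of document_data (field names taken from this diff, values are placeholders) is:

document_data = {
    "data_source": {
        "type": "website_crawl",
        "info_list": {
            "website_info_list": {
                "provider": "firecrawl",
                "job_id": "<firecrawl-job-id>",   # placeholder
                "urls": ["https://example.com", "https://example.com/docs"],
                "only_main_content": True,
            },
        },
    },
    "doc_form": "text_model",
    "doc_language": "English",
}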
| @@ -818,12 +866,12 @@ class DocumentService: | |||
| notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] | |||
| for notion_info in notion_info_list: | |||
| workspace_id = notion_info['workspace_id'] | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| @@ -835,6 +883,17 @@ class DocumentService: | |||
| "notion_page_icon": page['page_icon'], | |||
| "type": page['type'] | |||
| } | |||
| elif document_data["data_source"]["type"] == "website_crawl": | |||
| website_info = document_data["data_source"]['info_list']['website_info_list'] | |||
| urls = website_info['urls'] | |||
| for url in urls: | |||
| data_source_info = { | |||
| 'url': url, | |||
| 'provider': website_info['provider'], | |||
| 'job_id': website_info['job_id'], | |||
| 'only_main_content': website_info.get('only_main_content', False), | |||
| 'mode': 'crawl', | |||
| } | |||
| document.data_source_type = document_data["data_source"]["type"] | |||
| document.data_source_info = json.dumps(data_source_info) | |||
| document.name = file_name | |||
| @@ -873,6 +932,9 @@ class DocumentService: | |||
| notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] | |||
| for notion_info in notion_info_list: | |||
| count = count + len(notion_info['pages']) | |||
| elif document_data["data_source"]["type"] == "website_crawl": | |||
| website_info = document_data["data_source"]['info_list']['website_info_list'] | |||
| count = len(website_info['urls']) | |||
| batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT']) | |||
| if count > batch_upload_limit: | |||
| raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") | |||
| @@ -973,6 +1035,10 @@ class DocumentService: | |||
| if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ | |||
| 'notion_info_list']: | |||
| raise ValueError("Notion source info is required") | |||
| if args['data_source']['type'] == 'website_crawl': | |||
| if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ | |||
| 'website_info_list']: | |||
| raise ValueError("Website source info is required") | |||
| @classmethod | |||
| def process_rule_args_validate(cls, args: dict): | |||
| @@ -0,0 +1,171 @@ | |||
| import datetime | |||
| import json | |||
| from flask_login import current_user | |||
| from core.helper import encrypter | |||
| from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp | |||
| from extensions.ext_redis import redis_client | |||
| from extensions.ext_storage import storage | |||
| from services.auth.api_key_auth_service import ApiKeyAuthService | |||
| class WebsiteService: | |||
| @classmethod | |||
| def document_create_args_validate(cls, args: dict): | |||
| if 'url' not in args or not args['url']: | |||
| raise ValueError('url is required') | |||
| if 'options' not in args or not args['options']: | |||
| raise ValueError('options is required') | |||
| if 'limit' not in args['options'] or not args['options']['limit']: | |||
| raise ValueError('limit is required') | |||
| @classmethod | |||
| def crawl_url(cls, args: dict) -> dict: | |||
| provider = args.get('provider') | |||
| url = args.get('url') | |||
| options = args.get('options') | |||
| credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=current_user.current_tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| crawl_sub_pages = options.get('crawl_sub_pages', False) | |||
| only_main_content = options.get('only_main_content', False) | |||
| if not crawl_sub_pages: | |||
| params = { | |||
| 'crawlerOptions': { | |||
| "includes": [], | |||
| "excludes": [], | |||
| "generateImgAltText": True, | |||
| "limit": 1, | |||
| 'returnOnlyUrls': False, | |||
| 'pageOptions': { | |||
| 'onlyMainContent': only_main_content, | |||
| "includeHtml": False | |||
| } | |||
| } | |||
| } | |||
| else: | |||
| includes = options.get('includes').split(',') if options.get('includes') else [] | |||
| excludes = options.get('excludes').split(',') if options.get('excludes') else [] | |||
| params = { | |||
| 'crawlerOptions': { | |||
| "includes": includes if includes else [], | |||
| "excludes": excludes if excludes else [], | |||
| "generateImgAltText": True, | |||
| "limit": options.get('limit', 1), | |||
| 'returnOnlyUrls': False, | |||
| 'pageOptions': { | |||
| 'onlyMainContent': only_main_content, | |||
| "includeHtml": False | |||
| } | |||
| } | |||
| } | |||
| if options.get('max_depth'): | |||
| params['crawlerOptions']['maxDepth'] = options.get('max_depth') | |||
| job_id = firecrawl_app.crawl_url(url, params) | |||
| website_crawl_time_cache_key = f'website_crawl_{job_id}' | |||
| time = str(datetime.datetime.now().timestamp()) | |||
| redis_client.setex(website_crawl_time_cache_key, 3600, time) | |||
| return { | |||
| 'status': 'active', | |||
| 'job_id': job_id | |||
| } | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| @classmethod | |||
| def get_crawl_status(cls, job_id: str, provider: str) -> dict: | |||
| credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=current_user.current_tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| result = firecrawl_app.check_crawl_status(job_id) | |||
| crawl_status_data = { | |||
| 'status': result.get('status', 'active'), | |||
| 'job_id': job_id, | |||
| 'total': result.get('total', 0), | |||
| 'current': result.get('current', 0), | |||
| 'data': result.get('data', []) | |||
| } | |||
| if crawl_status_data['status'] == 'completed': | |||
| website_crawl_time_cache_key = f'website_crawl_{job_id}' | |||
| start_time = redis_client.get(website_crawl_time_cache_key) | |||
| if start_time: | |||
| end_time = datetime.datetime.now().timestamp() | |||
| time_consuming = abs(end_time - float(start_time)) | |||
| crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" | |||
| redis_client.delete(website_crawl_time_cache_key) | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| return crawl_status_data | |||
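The dict handed back to the caller then looks roughly as follows; `time_consuming` only appears once the job is completed and the start timestamp cached by crawl_url() is still present in Redis. Field names come from the code above; the values are illustrative.

```python
# Illustrative shape of get_crawl_status()'s return value for a finished job;
# field names come from the dict assembled above, values are made up.
example_status = {
    'status': 'completed',
    'job_id': 'abc123',
    'total': 10,                # pages Firecrawl reports for the job
    'current': 10,              # pages processed so far
    'data': [],                 # crawled page payloads returned by Firecrawl
    'time_consuming': '42.17',  # seconds between crawl start and completion check
}
```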
| @classmethod | |||
| def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: | |||
| credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| file_key = 'website_files/' + job_id + '.txt' | |||
| if storage.exists(file_key): | |||
| data = storage.load_once(file_key) | |||
| if data: | |||
| data = json.loads(data.decode('utf-8')) | |||
| else: | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| result = firecrawl_app.check_crawl_status(job_id) | |||
| if result.get('status') != 'completed': | |||
| raise ValueError('Crawl job is not completed') | |||
| data = result.get('data') | |||
| if data: | |||
| for item in data: | |||
| if item.get('source_url') == url: | |||
| return item | |||
| return None | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
| @classmethod | |||
| def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: | |||
| credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, | |||
| 'website', | |||
| provider) | |||
| if provider == 'firecrawl': | |||
| # decrypt api_key | |||
| api_key = encrypter.decrypt_token( | |||
| tenant_id=tenant_id, | |||
| token=credentials.get('config').get('api_key') | |||
| ) | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=credentials.get('config').get('base_url', None)) | |||
| params = { | |||
| 'pageOptions': { | |||
| 'onlyMainContent': only_main_content, | |||
| "includeHtml": False | |||
| } | |||
| } | |||
| result = firecrawl_app.scrape_url(url, params) | |||
| return result | |||
| else: | |||
| raise ValueError('Invalid provider') | |||
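Put together, a caller would typically start a crawl, poll its status, and then pull per-URL content (or scrape a single URL directly). The sketch below assumes the class name WebsiteService, a logged-in request context for the first two calls, and placeholder tenant/URL values; it is an illustration, not code from this diff.

```python
import time

# Hedged end-to-end sketch over the service methods above; WebsiteService,
# tenant_id and the URL are assumptions, and error handling is omitted.
tenant_id = '<tenant-id>'
target_url = 'https://docs.example.com'

job = WebsiteService.crawl_url(args)  # args as sketched earlier; returns {'status': 'active', 'job_id': ...}
while True:
    status = WebsiteService.get_crawl_status(job['job_id'], 'firecrawl')
    if status['status'] == 'completed':
        break
    time.sleep(5)  # simple polling interval

# fetch the crawled payload for one discovered URL (served from storage if cached)
page = WebsiteService.get_crawl_url_data(job['job_id'], 'firecrawl', target_url, tenant_id)

# or scrape a single URL directly, bypassing the crawl job
single = WebsiteService.get_scrape_url_data('firecrawl', target_url, tenant_id, only_main_content=True)
```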
| @@ -11,7 +11,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.source import DataSourceBinding | |||
| from models.source import DataSourceOauthBinding | |||
| @shared_task(queue='dataset') | |||
| @@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| page_id = data_source_info['notion_page_id'] | |||
| page_type = data_source_info['type'] | |||
| page_edited_time = data_source_info['last_edited_time'] | |||
| data_source_binding = DataSourceBinding.query.filter( | |||
| data_source_binding = DataSourceOauthBinding.query.filter( | |||
| db.and_( | |||
| DataSourceBinding.tenant_id == document.tenant_id, | |||
| DataSourceBinding.provider == 'notion', | |||
| DataSourceBinding.disabled == False, | |||
| DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| DataSourceOauthBinding.tenant_id == document.tenant_id, | |||
| DataSourceOauthBinding.provider == 'notion', | |||
| DataSourceOauthBinding.disabled == False, | |||
| DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' | |||
| ) | |||
| ).first() | |||
| if not data_source_binding: | |||
| @@ -0,0 +1,90 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from services.feature_service import FeatureService | |||
| @shared_task(queue='dataset') | |||
| def sync_website_document_indexing_task(dataset_id: str, document_id: str): | |||
| """ | |||
| Async process website document | |||
| :param dataset_id: | |||
| :param document_id: | |||
| Usage: sync_website_document_indexing_task.delay(dataset_id, document_id) | |||
| """ | |||
| start_at = time.perf_counter() | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id) | |||
| # check document limit | |||
| features = FeatureService.get_features(dataset.tenant_id) | |||
| try: | |||
| if features.billing.enabled: | |||
| vector_space = features.vector_space | |||
| if 0 < vector_space.limit <= vector_space.size: | |||
| raise ValueError("Your total number of documents plus the number of uploads have over the limit of " | |||
| "your subscription.") | |||
| except Exception as e: | |||
| document = db.session.query(Document).filter( | |||
| Document.id == document_id, | |||
| Document.dataset_id == dataset_id | |||
| ).first() | |||
| if document: | |||
| document.indexing_status = 'error' | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| redis_client.delete(sync_indexing_cache_key) | |||
| return | |||
| logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green')) | |||
| document = db.session.query(Document).filter( | |||
| Document.id == document_id, | |||
| Document.dataset_id == dataset_id | |||
| ).first() | |||
| try: | |||
| if document: | |||
| # clean old data | |||
| index_processor = IndexProcessorFactory(document.doc_form).init_index_processor() | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| index_processor.clean(dataset, index_node_ids) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| document.indexing_status = 'parsing' | |||
| document.processing_started_at = datetime.datetime.utcnow() | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| indexing_runner = IndexingRunner() | |||
| indexing_runner.run([document]) | |||
| redis_client.delete(sync_indexing_cache_key) | |||
| except Exception as ex: | |||
| document.indexing_status = 'error' | |||
| document.error = str(ex) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| logging.info(click.style(str(ex), fg='yellow')) | |||
| redis_client.delete(sync_indexing_cache_key) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) | |||
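As the docstring's Usage line indicates, the task is queued through Celery's .delay(); a minimal dispatch sketch follows. Only the .delay(...) call and the document_{id}_is_sync cache key come from the task itself; the caller-side duplicate guard, the module path, and the expiry value are assumptions.

```python
# Hedged sketch of dispatching the task. The module path, the caller-side
# duplicate guard and the 600s expiry are assumptions; only .delay(...) and the
# cache key name are taken from the task above.
from extensions.ext_redis import redis_client
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task  # assumed path


def trigger_website_sync(dataset_id: str, document_id: str) -> None:
    cache_key = 'document_{}_is_sync'.format(document_id)
    if redis_client.get(cache_key):
        return  # a sync for this document is already queued or running
    redis_client.setex(cache_key, 600, 1)
    sync_website_document_indexing_task.delay(dataset_id, document_id)
```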
| @@ -0,0 +1,33 @@ | |||
| import os | |||
| from unittest import mock | |||
| from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp | |||
| from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor | |||
| from core.rag.models.document import Document | |||
| from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response | |||
| def test_firecrawl_web_extractor_crawl_mode(mocker): | |||
| url = "https://firecrawl.dev" | |||
| api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-' | |||
| base_url = 'https://api.firecrawl.dev' | |||
| firecrawl_app = FirecrawlApp(api_key=api_key, | |||
| base_url=base_url) | |||
| params = { | |||
| 'crawlerOptions': { | |||
| "includes": [], | |||
| "excludes": [], | |||
| "generateImgAltText": True, | |||
| "maxDepth": 1, | |||
| "limit": 1, | |||
| 'returnOnlyUrls': False, | |||
| } | |||
| } | |||
| mocked_firecrawl = { | |||
| "jobId": "test", | |||
| } | |||
| mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl)) | |||
| job_id = firecrawl_app.crawl_url(url, params) | |||
| print(job_id) | |||
| assert isinstance(job_id, str) | |||