  if: steps.changed-files.outputs.any_changed == 'true'
  run: |
  uv run --directory api ruff --version
- uv run --directory api ruff check --diff ./
- uv run --directory api ruff format --check --diff ./
+ uv run --directory api ruff check ./
+ uv run --directory api ruff format --check ./
  - name: Dotenv check
  if: steps.changed-files.outputs.any_changed == 'true'
  "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
  "S302", # suspicious-marshal-usage, disallow use of `marshal` module
  "S311", # suspicious-non-cryptographic-random-usage
+ "G001", # don't use str format to logging messages
+ "G004", # don't use f-strings to format logging messages
  ]
  ignore = [
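
Note: the two rules added above (G001, G004) both flag eager string interpolation inside logging calls, which is what most of this diff fixes. A minimal sketch of the discouraged forms and the lazy %-style form the changes below switch to (the `user` variable is purely illustrative):

    import logging

    logger = logging.getLogger(__name__)
    user = "alice"  # made-up value, for illustration only

    # G001 flags str.format() in the call: logger.info("User {} logged in".format(user))
    # G004 flags f-strings in the call:    logger.info(f"User {user} logged in")
    # Preferred: pass a %-style template plus arguments, so formatting is deferred
    # until the logging framework decides the record will actually be emitted.
    logger.info("User %s logged in", user)
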
  initialize_extensions(app)
  end_time = time.perf_counter()
  if dify_config.DEBUG:
- logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)")
+ logging.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
  return app
  is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True
  if not is_enabled:
  if dify_config.DEBUG:
- logging.info(f"Skipped {short_name}")
+ logging.info("Skipped %s", short_name)
  continue
  start_time = time.perf_counter()
  ext.init_app(app)
  end_time = time.perf_counter()
  if dify_config.DEBUG:
- logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)")
+ logging.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
  def create_migrations_app():
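
Note: a small self-contained sketch (illustrative names only) of what the lazy form buys for calls like the ones above: if a record is filtered out, its arguments are never formatted.

    import logging

    logging.basicConfig(level=logging.WARNING)  # INFO records will be dropped
    logger = logging.getLogger(__name__)

    class Expensive:
        def __str__(self) -> str:
            print("formatting happened")
            return "expensive value"

    # With lazy %-style arguments, __str__ is never called here, because the
    # INFO record is filtered before the message is rendered.
    logger.info("value: %s", Expensive())
    # An f-string would have built the full message eagerly, even though the
    # record is then discarded.
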
  account = db.session.query(Account).where(Account.email == email).one_or_none()
  if not account:
- click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
+ click.echo(click.style(f"Account not found for email: {email}", fg="red"))
  return
  try:
  valid_password(new_password)
  except:
- click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red"))
+ click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red"))
  return
  # generate password salt
  account = db.session.query(Account).where(Account.email == email).one_or_none()
  if not account:
- click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
+ click.echo(click.style(f"Account not found for email: {email}", fg="red"))
  return
  try:
  email_validate(new_email)
  except:
- click.echo(click.style("Invalid email: {}".format(new_email), fg="red"))
+ click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
  return
  account.email = new_email
  click.echo(
  click.style(
- "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
+ f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.",
  fg="green",
  )
  )
  f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
  )
  try:
- click.echo("Creating app annotation index: {}".format(app.id))
+ click.echo(f"Creating app annotation index: {app.id}")
  app_annotation_setting = (
  db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
  )
  if not app_annotation_setting:
  skipped_count = skipped_count + 1
- click.echo("App annotation setting disabled: {}".format(app.id))
+ click.echo(f"App annotation setting disabled: {app.id}")
  continue
  # get dataset_collection_binding info
  dataset_collection_binding = (
  .first()
  )
  if not dataset_collection_binding:
- click.echo("App annotation collection binding not found: {}".format(app.id))
+ click.echo(f"App annotation collection binding not found: {app.id}")
  continue
  annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
  dataset = Dataset(
  create_count += 1
  except Exception as e:
  click.echo(
- click.style(
- "Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red"
- )
+ click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red")
  )
  continue
  f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
  )
  try:
- click.echo("Creating dataset vector database index: {}".format(dataset.id))
+ click.echo(f"Creating dataset vector database index: {dataset.id}")
  if dataset.index_struct_dict:
  if dataset.index_struct_dict["type"] == vector_type:
  skipped_count = skipped_count + 1
  create_count += 1
  except Exception as e:
  db.session.rollback()
- click.echo(
- click.style("Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red")
- )
+ click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red"))
  continue
  click.echo(
  break
  for app in apps:
- click.echo("Converting app: {}".format(app.id))
+ click.echo(f"Converting app: {app.id}")
  try:
  app.mode = AppMode.AGENT_CHAT.value
  )
  db.session.commit()
- click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
+ click.echo(click.style(f"Converted app: {app.id}", fg="green"))
  except Exception as e:
- click.echo(click.style("Convert app error: {} {}".format(e.__class__.__name__, str(e)), fg="red"))
+ click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red"))
- click.echo(click.style("Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green"))
+ click.echo(click.style(f"Conversion complete. Converted {len(proceeded_app_ids)} agent apps.", fg="green"))
  @click.command("add-qdrant-index", help="Add Qdrant index.")
  click.echo(
  click.style(
- "Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
+ f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}",
  fg="green",
  )
  )
  if tenant:
  accounts = tenant.get_accounts()
  if not accounts:
- print("Fix failed for app {}".format(app.id))
+ print(f"Fix failed for app {app.id}")
  continue
  account = accounts[0]
- print("Fixing missing site for app {}".format(app.id))
+ print(f"Fixing missing site for app {app.id}")
  app_was_created.send(app, account=account)
  except Exception:
  failed_app_ids.append(app_id)
- click.echo(click.style("Failed to fix missing site for app {}".format(app_id), fg="red"))
- logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
+ click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red"))
+ logging.exception("Failed to fix app related site missing issue, app_id: %s", app_id)
  continue
  if not processed_count:
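
Note: the commands above repeatedly use the same click pattern; a minimal sketch of it (the app id is a made-up placeholder). Unlike the logging calls, these messages are always printed, so building them eagerly with an f-string is fine and simply reads better than str.format().

    import click

    app_id = "app-1234"  # hypothetical id, for illustration only

    # click.style() wraps the text in ANSI colour codes; click.echo() writes it to stdout.
    click.echo(click.style(f"Converting app: {app_id}", fg="green"))
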
  case RemoteSettingsSourceName.NACOS:
  remote_source = NacosSettingsSource(current_state)
  case _:
- logger.warning(f"Unsupported remote source: {remote_source_name}")
+ logger.warning("Unsupported remote source: %s", remote_source_name)
  return {}
  d: dict[str, Any] = {}
  @computed_field
  def CELERY_RESULT_BACKEND(self) -> str | None:
- return (
- "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
- if self.CELERY_BACKEND == "database"
- else self.CELERY_BROKER_URL
- )
+ return f"db+{self.SQLALCHEMY_DATABASE_URI}" if self.CELERY_BACKEND == "database" else self.CELERY_BROKER_URL
  @property
  def BROKER_USE_SSL(self) -> bool:
  code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
  if code == 200:
  if not body:
- logger.error(f"get_json_from_net load configs failed, body is {body}")
+ logger.error("get_json_from_net load configs failed, body is %s", body)
  return None
  data = json.loads(body)
  data = data["configurations"]
  # if the length is 0 it is returned directly
  if len(notifications) == 0:
  return
- url = "{}/notifications/v2".format(self.config_url)
+ url = f"{self.config_url}/notifications/v2"
  params = {
  "appId": self.app_id,
  "cluster": self.cluster,
  return
  if http_code == 200:
  if not body:
- logger.error(f"_long_poll load configs failed,body is {body}")
+ logger.error("_long_poll load configs failed,body is %s", body)
  return
  data = json.loads(body)
  for entry in data:
  time.sleep(60 * 10) # 10 minutes
  def _do_heart_beat(self, namespace):
- url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
+ url = f"{self.config_url}/configs/{self.app_id}/{self.cluster}/{namespace}?ip={self.ip}"
  try:
  code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
  if code == 200:
  if not body:
- logger.error(f"_do_heart_beat load configs failed,body is {body}")
+ logger.error("_do_heart_beat load configs failed,body is %s", body)
  return None
  data = json.loads(body)
  if self.last_release_key == data["releaseKey"]:
  def no_key_cache_key(namespace, key):
- return "{}{}{}".format(namespace, len(namespace), key)
+ return f"{namespace}{len(namespace)}{key}"
  # Returns whether the obtained value is obtained, and None if it does not
  if lang in languages:
  return lang
- error = "{lang} is not a valid language.".format(lang=lang)
+ error = f"{lang} is not a valid language."
  raise ValueError(error)
  raise Forbidden()
  job_id = str(job_id)
- app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
+ app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
  cache_result = redis_client.get(app_annotation_job_key)
  if cache_result is None:
  raise ValueError("The job does not exist.")
  job_status = cache_result.decode()
  error_msg = ""
  if job_status == "error":
- app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
+ app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
  error_msg = redis_client.get(app_annotation_error_key).decode()
  return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
  raise Forbidden()
  job_id = str(job_id)
- indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
+ indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
  cache_result = redis_client.get(indexing_cache_key)
  if cache_result is None:
  raise ValueError("The job does not exist.")
  job_status = cache_result.decode()
  error_msg = ""
  if job_status == "error":
- indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id))
+ indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
  error_msg = redis_client.get(indexing_error_msg_key).decode()
  return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
  if args["keyword"]:
  query = query.join(Message, Message.conversation_id == Conversation.id).where(
  or_(
- Message.query.ilike("%{}%".format(args["keyword"])),
- Message.answer.ilike("%{}%".format(args["keyword"])),
+ Message.query.ilike(f"%{args['keyword']}%"),
+ Message.answer.ilike(f"%{args['keyword']}%"),
  )
  )
  query = db.select(Conversation).where(Conversation.app_id == app_model.id)
  if args["keyword"]:
- keyword_filter = "%{}%".format(args["keyword"])
+ keyword_filter = f"%{args['keyword']}%"
  query = (
  query.join(
  Message,
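
Note: one detail in the keyword filters above: the dictionary key inside the f-string has to use the other quote character, since reusing the same quote inside an f-string expression only became legal in Python 3.12 (PEP 701). A tiny check:

    args = {"keyword": "hello"}
    keyword_filter = f"%{args['keyword']}%"  # single quotes inside a double-quoted f-string
    assert keyword_filter == "%hello%"
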
  oauth_provider.get_access_token(code)
  except requests.exceptions.HTTPError as e:
  logging.exception(
- f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}"
+ "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
  )
  return {"error": "OAuth data source process failed"}, 400
  try:
  oauth_provider.sync_data_source(binding_id)
  except requests.exceptions.HTTPError as e:
- logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
+ logging.exception(
+ "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
+ )
  return {"error": "OAuth data source process failed"}, 400
  return {"result": "success"}, 200
  user_info = oauth_provider.get_user_info(token)
  except requests.exceptions.RequestException as e:
  error_text = e.response.text if e.response else str(e)
- logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}")
+ logging.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
  return {"error": "OAuth process failed"}, 400
  if invite_token and RegisterService.is_valid_invite_token(invite_token):
  raise DocumentAlreadyFinishedError()
  retry_documents.append(document)
  except Exception:
- logging.exception(f"Failed to retry document, document id: {document_id}")
+ logging.exception("Failed to retry document, document id: %s", document_id)
  continue
  # retry document
  DocumentService.retry_document(dataset_id, retry_documents)
  raise ProviderNotInitializeError(ex.description)
  segment_ids = request.args.getlist("segment_id")
- document_indexing_cache_key = "document_{}_indexing".format(document.id)
+ document_indexing_cache_key = f"document_{document.id}_indexing"
  cache_result = redis_client.get(document_indexing_cache_key)
  if cache_result is not None:
  raise InvalidActionError("Document is being indexed, please try again later")
  raise ValueError("The CSV file is empty.")
  # async job
  job_id = str(uuid.uuid4())
- indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
+ indexing_cache_key = f"segment_batch_import_{str(job_id)}"
  # send batch add segments task
  redis_client.setnx(indexing_cache_key, "waiting")
  batch_create_segment_to_index_task.delay(
  @account_initialization_required
  def get(self, job_id):
  job_id = str(job_id)
- indexing_cache_key = "segment_batch_import_{}".format(job_id)
+ indexing_cache_key = f"segment_batch_import_{job_id}"
  cache_result = redis_client.get(indexing_cache_key)
  if cache_result is None:
  raise ValueError("The job does not exist.")
  ):
  res.append(installed_app)
  installed_app_list = res
- logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}")
+ logger.debug("installed_app_list: %s, user_id: %s", installed_app_list, user_id)
  installed_app_list.sort(
  key=lambda app: (
  try:
  response = requests.get(check_update_url, {"current_version": args.get("current_version")})
  except Exception as error:
- logging.warning("Check update version error: {}.".format(str(error)))
+ logging.warning("Check update version error: %s.", str(error))
  result["version"] = args.get("current_version")
  return result
  # Compare versions
  return latest > current
  except version.InvalidVersion:
- logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}")
+ logging.warning("Invalid version format: latest=%s, current=%s", latest_version, current_version)
  return False
  )
  except Exception as ex:
  logging.exception(
- f"Failed to update default model, model type: {model_setting['model_type']},"
- f" model:{model_setting.get('model')}"
+ "Failed to update default model, model type: %s, model: %s",
+ model_setting["model_type"],
+ model_setting.get("model"),
  )
  raise ex
  )
  except CredentialsValidateFailedError as ex:
  logging.exception(
- f"Failed to save model credentials, tenant_id: {tenant_id},"
- f" model: {args.get('model')}, model_type: {args.get('model_type')}"
+ "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
+ tenant_id,
+ args.get("model"),
+ args.get("model_type"),
  )
  raise ValueError(str(ex))
  @validate_app_token
  def get(self, app_model: App, job_id, action):
  job_id = str(job_id)
- app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
+ app_annotation_job_key = f"{action}_app_annotation_job_{str(job_id)}"
  cache_result = redis_client.get(app_annotation_job_key)
  if cache_result is None:
  raise ValueError("The job does not exist.")
  job_status = cache_result.decode()
  error_msg = ""
  if job_status == "error":
- app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
+ app_annotation_error_key = f"{action}_app_annotation_error_{str(job_id)}"
  error_msg = redis_client.get(app_annotation_error_key).decode()
  return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
  if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
  raise GenerateTaskStoppedError()
  else:
- logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")
+ logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
  raise e
  start_listener_time = time.time()
  yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
  except Exception:
- logger.exception(f"Failed to listen audio message, task_id: {task_id}")
+ logger.exception("Failed to listen audio message, task_id: %s", task_id)
  break
  if tts_publisher:
  yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
  if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
  raise GenerateTaskStoppedError()
  else:
- logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")
+ logger.exception("Failed to handle response, conversation_id: %s", conversation.id)
  raise e
  def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
  try:
  runner.run()
  except GenerateTaskStoppedError as e:
- logger.warning(f"Task stopped: {str(e)}")
+ logger.warning("Task stopped: %s", str(e))
  pass
  except InvokeAuthorizationError:
  queue_manager.publish_error(
  raise GenerateTaskStoppedError()
  else:
  logger.exception(
- f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
+ "Fails to process generate task pipeline, task_id: %s", application_generate_entity.task_id
  )
  raise e
  else:
  yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
  except Exception:
- logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
+ logger.exception("Fails to get audio trunk, task_id: %s", task_id)
  break
  if tts_publisher:
  yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
  return annotation
  except Exception as e:
- logger.warning(f"Query annotation failed, exception: {str(e)}.")
+ logger.warning("Query annotation failed, exception: %s.", str(e))
  return None
  return None
  conversation.name = name
  except Exception as e:
  if dify_config.DEBUG:
- logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")
+ logging.exception("generate conversation name failed, conversation_id: %s", conversation_id)
  pass
  db.session.merge(conversation)
  credentials=copy_credentials,
  )
  except Exception as ex:
- logger.warning(f"get custom model schema failed, {ex}")
+ logger.warning("get custom model schema failed, %s", ex)
  continue
  if not custom_model_schema:
  credentials=model_configuration.credentials,
  )
  except Exception as ex:
- logger.warning(f"get custom model schema failed, {ex}")
+ logger.warning("get custom model schema failed, %s", ex)
  continue
  if not custom_model_schema:
  :param params: the request params
  :return: the response json
  """
- headers = {"Content-Type": "application/json", "Authorization": "Bearer {}".format(self.api_key)}
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
  url = self.api_endpoint
  raise ValueError("request connection error")
  if response.status_code != 200:
- raise ValueError(
- "request error, status_code: {}, content: {}".format(response.status_code, response.text[:100])
- )
+ raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}")
  return cast(dict, response.json())
  # Check for extension module file
  if (extension_name + ".py") not in file_names:
- logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+ logging.warning("Missing %s.py file in %s, Skip.", extension_name, subdir_path)
  continue
  # Check for builtin flag and position
  break
  if not extension_class:
- logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
+ logging.warning("Missing subclass of %s in %s, Skip.", cls.__name__, module_name)
  continue
  # Load schema if not builtin
  if not builtin:
  json_path = os.path.join(subdir_path, "schema.json")
  if not os.path.exists(json_path):
- logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
+ logging.warning("Missing schema.json file in %s, Skip.", subdir_path)
  continue
  with open(json_path, encoding="utf-8") as f:
  """
  # get params from config
  if not self.config:
- raise ValueError("config is required, config: {}".format(self.config))
+ raise ValueError(f"config is required, config: {self.config}")
  api_based_extension_id = self.config.get("api_based_extension_id")
  assert api_based_extension_id is not None, "api_based_extension_id is required"
  # request api
  requestor = APIBasedExtensionRequestor(api_endpoint=api_based_extension.api_endpoint, api_key=api_key)
  except Exception as e:
- raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(self.variable, e))
+ raise ValueError(f"[External data tool] API query failed, variable: {self.variable}, error: {e}")
  response_json = requestor.request(
  point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY,
  if not isinstance(response_json["result"], str):
  raise ValueError(
- "[External data tool] API query failed, variable: {}, error: result is not string".format(self.variable)
+ f"[External data tool] API query failed, variable: {self.variable}, error: result is not string"
  )
  return response_json["result"]
  if moderation_result is True:
  return True
  except Exception:
- logger.exception(f"Fails to check moderation, provider_name: {provider_name}")
+ logger.exception("Fails to check moderation, provider_name: %s", provider_name)
  raise InvokeBadRequestError("Rate limit exceeded, please try again later.")
  return False
  spec.loader.exec_module(module)
  return module
  except Exception as e:
- logging.exception(f"Failed to load module {module_name} from script file '{py_file_path!r}'")
+ logging.exception("Failed to load module %s from script file '%s'", module_name, repr(py_file_path))
  raise e
  if response.status_code not in STATUS_FORCELIST:
  return response
  else:
- logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list")
+ logging.warning(
+ "Received status code %s for URL %s which is in the force list", response.status_code, url
+ )
  except httpx.RequestError as e:
- logging.warning(f"Request to URL {url} failed on attempt {retries + 1}: {e}")
+ logging.warning("Request to URL %s failed on attempt %s: %s", url, retries + 1, e)
  if max_retries == 0:
  raise
  documents=documents,
  )
  except DocumentIsPausedError:
- raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+ raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
  except ProviderTokenNotInitError as e:
  dataset_document.indexing_status = "error"
  dataset_document.error = str(e.description)
  dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
  db.session.commit()
  except ObjectDeletedError:
- logging.warning("Document deleted, document id: {}".format(dataset_document.id))
+ logging.warning("Document deleted, document id: %s", dataset_document.id)
  except Exception as e:
  logging.exception("consume document failed")
  dataset_document.indexing_status = "error"
  index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
  )
  except DocumentIsPausedError:
- raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+ raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
  except ProviderTokenNotInitError as e:
  dataset_document.indexing_status = "error"
  dataset_document.error = str(e.description)
  index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
  )
  except DocumentIsPausedError:
- raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
+ raise DocumentIsPausedError(f"Document paused, document id: {dataset_document.id}")
  except ProviderTokenNotInitError as e:
  dataset_document.indexing_status = "error"
  dataset_document.error = str(e.description)
  except Exception:
  logging.exception(
  "Delete image_files failed while indexing_estimate, \
- image_upload_file_is: {}".format(upload_file_id)
+ image_upload_file_is: %s",
+ upload_file_id,
  )
  db.session.delete(image_file)
  @staticmethod
  def _check_document_paused_status(document_id: str):
- indexing_cache_key = "document_{}_is_paused".format(document_id)
+ indexing_cache_key = f"document_{document_id}_is_paused"
  result = redis_client.get(indexing_cache_key)
  if result:
  raise DocumentIsPausedError()
  error = str(e)
  error_step = "generate rule config"
  except Exception as e:
- logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
+ logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
  rule_config["error"] = str(e)
  rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
  error_step = "generate conversation opener"
  except Exception as e:
- logging.exception(f"Failed to generate rule config, model: {model_config.get('name')}")
+ logging.exception("Failed to generate rule config, model: %s", model_config.get("name"))
  rule_config["error"] = str(e)
  rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
  return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
  except Exception as e:
  logging.exception(
- f"Failed to invoke LLM model, model: {model_config.get('name')}, language: {code_language}"
+ "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
  )
  return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
  error = str(e)
  return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
  except Exception as e:
- logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+ logging.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
  return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
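
Note: many call sites in this diff use logging.exception; a short sketch (the model name is a made-up placeholder) of how it combines with lazy arguments:

    import logging

    logger = logging.getLogger(__name__)

    def invoke(model_name: str) -> None:
        try:
            raise RuntimeError("boom")
        except Exception:
            # exception() logs at ERROR level and appends the current traceback
            # itself, so the message only needs the identifying context, passed
            # as lazy %-style arguments.
            logger.exception("Failed to invoke LLM model, model: %s", model_name)

    invoke("example-model")
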
  status_queue: Queue to put status updates.
  """
  endpoint_url = urljoin(self.url, sse_data)
- logger.info(f"Received endpoint URL: {endpoint_url}")
+ logger.info("Received endpoint URL: %s", endpoint_url)
  if not self._validate_endpoint_url(endpoint_url):
  error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
  """
  try:
  message = types.JSONRPCMessage.model_validate_json(sse_data)
- logger.debug(f"Received server message: {message}")
+ logger.debug("Received server message: %s", message)
  session_message = SessionMessage(message)
  read_queue.put(session_message)
  except Exception as exc:
  case "message":
  self._handle_message_event(sse.data, read_queue)
  case _:
- logger.warning(f"Unknown SSE event: {sse.event}")
+ logger.warning("Unknown SSE event: %s", sse.event)
  def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
  """Read and process SSE events.
  for sse in event_source.iter_sse():
  self._handle_sse_event(sse, read_queue, status_queue)
  except httpx.ReadError as exc:
- logger.debug(f"SSE reader shutting down normally: {exc}")
+ logger.debug("SSE reader shutting down normally: %s", exc)
  except Exception as exc:
  read_queue.put(exc)
  finally:
  ),
  )
  response.raise_for_status()
- logger.debug(f"Client message sent successfully: {response.status_code}")
+ logger.debug("Client message sent successfully: %s", response.status_code)
  def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
  """Handle writing messages to the server.
  except queue.Empty:
  continue
  except httpx.ReadError as exc:
- logger.debug(f"Post writer shutting down normally: {exc}")
+ logger.debug("Post writer shutting down normally: %s", exc)
  except Exception as exc:
  logger.exception("Error writing messages")
  write_queue.put(exc)
  ),
  )
  response.raise_for_status()
- logger.debug(f"Client message sent successfully: {response.status_code}")
+ logger.debug("Client message sent successfully: %s", response.status_code)
  except Exception as exc:
  logger.exception("Error sending message")
  raise
  if sse.event == "message":
  try:
  message = types.JSONRPCMessage.model_validate_json(sse.data)
- logger.debug(f"Received server message: {message}")
+ logger.debug("Received server message: %s", message)
  yield SessionMessage(message)
  except Exception as exc:
  logger.exception("Error parsing server message")
  yield exc
  else:
- logger.warning(f"Unknown SSE event: {sse.event}")
+ logger.warning("Unknown SSE event: %s", sse.event)
  except Exception as exc:
  logger.exception("Error reading SSE messages")
  yield exc
  new_session_id = response.headers.get(MCP_SESSION_ID)
  if new_session_id:
  self.session_id = new_session_id
- logger.info(f"Received session ID: {self.session_id}")
+ logger.info("Received session ID: %s", self.session_id)
  def _handle_sse_event(
  self,
  if sse.event == "message":
  try:
  message = JSONRPCMessage.model_validate_json(sse.data)
- logger.debug(f"SSE message: {message}")
+ logger.debug("SSE message: %s", message)
  # If this is a response and we have original_request_id, replace it
  if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
  logger.debug("Received ping event")
  return False
  else:
- logger.warning(f"Unknown SSE event: {sse.event}")
+ logger.warning("Unknown SSE event: %s", sse.event)
  return False
  def handle_get_stream(
  self._handle_sse_event(sse, server_to_client_queue)
  except Exception as exc:
- logger.debug(f"GET stream error (non-fatal): {exc}")
+ logger.debug("GET stream error (non-fatal): %s", exc)
  def _handle_resumption_request(self, ctx: RequestContext) -> None:
  """Handle a resumption request using GET with SSE."""
  # Check if this is a resumption request
  is_resumption = bool(metadata and metadata.resumption_token)
- logger.debug(f"Sending client message: {message}")
+ logger.debug("Sending client message: %s", message)
  # Handle initialized notification
  if self._is_initialized_notification(message):
  if response.status_code == 405:
  logger.debug("Server does not allow session termination")
  elif response.status_code != 200:
- logger.warning(f"Session termination failed: {response.status_code}")
+ logger.warning("Session termination failed: %s", response.status_code)
  except Exception as exc:
- logger.warning(f"Session termination failed: {exc}")
+ logger.warning("Session termination failed: %s", exc)
  def get_session_id(self) -> str | None:
  """Get the current session ID."""
  self.connect_server(client_factory, method_name)
  else:
  try:
- logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
+ logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
  self.connect_server(sse_client, "sse")
  except MCPConnectionError:
  logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
  self._handle_incoming(notification)
  except Exception as e:
  # For other validation errors, log and continue
- logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
+ logging.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
  else: # Response or error
  response_queue = self._response_streams.get(message.message.root.id)
  if response_queue is not None:
  if dify_config.DEBUG:
  logger.info(
- f"Model LB\nid: {config.id}\nname:{config.name}\n"
- f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
- f"model_type: {self._model_type.value}\nmodel: {self._model}"
+ """Model LB
+ id: %s
+ name:%s
+ tenant_id: %s
+ provider: %s
+ model_type: %s
+ model: %s""",
+ config.id,
+ config.name,
+ self._tenant_id,
+ self._provider,
+ self._model_type.value,
+ self._model,
  )
  return config
  if callback.raise_error:
  raise e
  else:
- logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
+ logger.warning(
+ "Callback %s on_before_invoke failed with error %s", callback.__class__.__name__, e
+ )
  def _trigger_new_chunk_callbacks(
  self,
  if callback.raise_error:
  raise e
  else:
- logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
+ logger.warning("Callback %s on_new_chunk failed with error %s", callback.__class__.__name__, e)
  def _trigger_after_invoke_callbacks(
  self,
  if callback.raise_error:
  raise e
  else:
- logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
+ logger.warning(
+ "Callback %s on_after_invoke failed with error %s", callback.__class__.__name__, e
+ )
  def _trigger_invoke_error_callbacks(
  self,
  if callback.raise_error:
  raise e
  else:
- logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
+ logger.warning(
+ "Callback %s on_invoke_error failed with error %s", callback.__class__.__name__, e
+ )
  result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
  return result
  except Exception as e:
- logger.exception(f"Moderation Output error, app_id: {app_id}")
+ logger.exception("Moderation Output error, app_id: %s", app_id)
  return None
  try:
  return self.trace_client.get_project_url()
  except Exception as e:
- logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
+ logger.info("Aliyun get run url failed: %s", str(e), exc_info=True)
  raise ValueError(f"Aliyun get run url failed: {str(e)}")
  def workflow_trace(self, trace_info: WorkflowTraceInfo):
  node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
  return node_span
  except Exception as e:
- logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True)
+ logging.debug("Error occurred in build_workflow_node_span: %s", e, exc_info=True)
  return None
  def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
  if response.status_code == 405:
  return True
  else:
- logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
+ logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
  return False
  except requests.exceptions.RequestException as e:
- logger.debug(f"AliyunTrace API check failed: {str(e)}")
+ logger.debug("AliyunTrace API check failed: %s", str(e))
  raise ValueError(f"AliyunTrace API check failed: {str(e)}")
  def get_project_url(self):
  try:
  self.exporter.export(spans_to_export)
  except Exception as e:
- logger.debug(f"Error exporting spans: {e}")
+ logger.debug("Error exporting spans: %s", e)
  def shutdown(self):
  with self.condition:
  # Create a named tracer instead of setting the global provider
  tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
- logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
+ logger.info("[Arize/Phoenix] Created tracer with name: %s", tracer_name)
  return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
  except Exception as e:
- logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
+ logger.error("[Arize/Phoenix] Failed to setup the tracer: %s", str(e), exc_info=True)
  raise
  self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  def trace(self, trace_info: BaseTraceInfo):
- logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
+ logger.info("[Arize/Phoenix] Trace: %s", trace_info)
  try:
  if isinstance(trace_info, WorkflowTraceInfo):
  self.workflow_trace(trace_info)
  self.generate_name_trace(trace_info)
  except Exception as e:
- logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
+ logger.error("[Arize/Phoenix] Error in the trace: %s", str(e), exc_info=True)
  raise
  def workflow_trace(self, trace_info: WorkflowTraceInfo):
  trace_id = uuid_to_trace_id(trace_info.message_id)
  tool_span_id = RandomIdGenerator().generate_span_id()
- logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
+ logger.info("[Arize/Phoenix] Creating tool trace with trace_id: %s, span_id: %s", trace_id, tool_span_id)
| # Create span context with the same trace_id as the parent | # Create span context with the same trace_id as the parent | ||||
| # todo: Create with the appropriate parent span context, so that the tool span is | # todo: Create with the appropriate parent span context, so that the tool span is | ||||
| span.set_attribute("test", "true") | span.set_attribute("test", "true") | ||||
| return True | return True | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True) | |||||
| logger.info("[Arize/Phoenix] API check failed: %s", str(e), exc_info=True) | |||||
| raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") | raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}") | ||||
| def get_project_url(self): | def get_project_url(self): | ||||
| else: | else: | ||||
| return f"{self.arize_phoenix_config.endpoint}/projects/" | return f"{self.arize_phoenix_config.endpoint}/projects/" | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True) | |||||
| logger.info("[Arize/Phoenix] Get run url failed: %s", str(e), exc_info=True) | |||||
| raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}") | raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}") | ||||
| def _get_workflow_nodes(self, workflow_run_id: str): | def _get_workflow_nodes(self, workflow_run_id: str): |
| try: | try: | ||||
| return self.langfuse_client.auth_check() | return self.langfuse_client.auth_check() | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.debug(f"LangFuse API check failed: {str(e)}") | |||||
| logger.debug("LangFuse API check failed: %s", str(e)) | |||||
| raise ValueError(f"LangFuse API check failed: {str(e)}") | raise ValueError(f"LangFuse API check failed: {str(e)}") | ||||
| def get_project_key(self): | def get_project_key(self): | ||||
| projects = self.langfuse_client.client.projects.get() | projects = self.langfuse_client.client.projects.get() | ||||
| return projects.data[0].id | return projects.data[0].id | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.debug(f"LangFuse get project key failed: {str(e)}") | |||||
| logger.debug("LangFuse get project key failed: %s", str(e)) | |||||
| raise ValueError(f"LangFuse get project key failed: {str(e)}") | raise ValueError(f"LangFuse get project key failed: {str(e)}") |
| self.langsmith_client.delete_project(project_name=random_project_name) | self.langsmith_client.delete_project(project_name=random_project_name) | ||||
| return True | return True | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.debug(f"LangSmith API check failed: {str(e)}") | |||||
| logger.debug("LangSmith API check failed: %s", str(e)) | |||||
| raise ValueError(f"LangSmith API check failed: {str(e)}") | raise ValueError(f"LangSmith API check failed: {str(e)}") | ||||
| def get_project_url(self): | def get_project_url(self): | ||||
| ) | ) | ||||
| return project_url.split("/r/")[0] | return project_url.split("/r/")[0] | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.debug(f"LangSmith get run url failed: {str(e)}") | |||||
| logger.debug("LangSmith get run url failed: %s", str(e)) | |||||
| raise ValueError(f"LangSmith get run url failed: {str(e)}") | raise ValueError(f"LangSmith get run url failed: {str(e)}") |
| self.opik_client.auth_check() | self.opik_client.auth_check() | ||||
| return True | return True | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.info(f"Opik API check failed: {str(e)}", exc_info=True) | |||||
| logger.info("Opik API check failed: %s", str(e), exc_info=True) | |||||
| raise ValueError(f"Opik API check failed: {str(e)}") | raise ValueError(f"Opik API check failed: {str(e)}") | ||||
| def get_project_url(self): | def get_project_url(self): | ||||
| try: | try: | ||||
| return self.opik_client.get_project_url(project_name=self.project) | return self.opik_client.get_project_url(project_name=self.project) | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.info(f"Opik get run url failed: {str(e)}", exc_info=True) | |||||
| logger.info("Opik get run url failed: %s", str(e), exc_info=True) | |||||
| raise ValueError(f"Opik get run url failed: {str(e)}") | raise ValueError(f"Opik get run url failed: {str(e)}") |
| # create new tracing_instance and update the cache if it absent | # create new tracing_instance and update the cache if it absent | ||||
| tracing_instance = trace_instance(config_class(**decrypt_trace_config)) | tracing_instance = trace_instance(config_class(**decrypt_trace_config)) | ||||
| cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance | cls.ops_trace_instances_cache[decrypt_trace_config_key] = tracing_instance | ||||
| logging.info(f"new tracing_instance for app_id: {app_id}") | |||||
| logging.info("new tracing_instance for app_id: %s", app_id) | |||||
| return tracing_instance | return tracing_instance | ||||
| @classmethod | @classmethod | ||||
| trace_task.app_id = self.app_id | trace_task.app_id = self.app_id | ||||
| trace_manager_queue.put(trace_task) | trace_manager_queue.put(trace_task) | ||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception(f"Error adding trace task, trace_type {trace_task.trace_type}") | |||||
| logging.exception("Error adding trace task, trace_type %s", trace_task.trace_type) | |||||
| finally: | finally: | ||||
| self.start_timer() | self.start_timer() | ||||
| project_url = f"https://wandb.ai/{self.weave_client._project_id()}" | project_url = f"https://wandb.ai/{self.weave_client._project_id()}" | ||||
| return project_url | return project_url | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.debug(f"Weave get run url failed: {str(e)}") | |||||
| logger.debug("Weave get run url failed: %s", str(e)) | |||||
| raise ValueError(f"Weave get run url failed: {str(e)}") | raise ValueError(f"Weave get run url failed: {str(e)}") | ||||
| def trace(self, trace_info: BaseTraceInfo): | def trace(self, trace_info: BaseTraceInfo): | ||||
| logger.debug(f"Trace info: {trace_info}") | |||||
| logger.debug("Trace info: %s", trace_info) | |||||
| if isinstance(trace_info, WorkflowTraceInfo): | if isinstance(trace_info, WorkflowTraceInfo): | ||||
| self.workflow_trace(trace_info) | self.workflow_trace(trace_info) | ||||
| if isinstance(trace_info, MessageTraceInfo): | if isinstance(trace_info, MessageTraceInfo): | ||||
| print("Weave login successful") | print("Weave login successful") | ||||
| return True | return True | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.debug(f"Weave API check failed: {str(e)}") | |||||
| logger.debug("Weave API check failed: %s", str(e)) | |||||
| raise ValueError(f"Weave API check failed: {str(e)}") | raise ValueError(f"Weave API check failed: {str(e)}") | ||||
| def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): | def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None): |
| self._config = KeywordTableConfig() | self._config = KeywordTableConfig() | ||||
| def create(self, texts: list[Document], **kwargs) -> BaseKeyword: | def create(self, texts: list[Document], **kwargs) -> BaseKeyword: | ||||
| lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) | |||||
| lock_name = f"keyword_indexing_lock_{self.dataset.id}" | |||||
| with redis_client.lock(lock_name, timeout=600): | with redis_client.lock(lock_name, timeout=600): | ||||
| keyword_table_handler = JiebaKeywordTableHandler() | keyword_table_handler = JiebaKeywordTableHandler() | ||||
| keyword_table = self._get_dataset_keyword_table() | keyword_table = self._get_dataset_keyword_table() | ||||
| return self | return self | ||||
| def add_texts(self, texts: list[Document], **kwargs): | def add_texts(self, texts: list[Document], **kwargs): | ||||
| lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) | |||||
| lock_name = f"keyword_indexing_lock_{self.dataset.id}" | |||||
| with redis_client.lock(lock_name, timeout=600): | with redis_client.lock(lock_name, timeout=600): | ||||
| keyword_table_handler = JiebaKeywordTableHandler() | keyword_table_handler = JiebaKeywordTableHandler() | ||||
| return id in set.union(*keyword_table.values()) | return id in set.union(*keyword_table.values()) | ||||
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) | |||||
| lock_name = f"keyword_indexing_lock_{self.dataset.id}" | |||||
| with redis_client.lock(lock_name, timeout=600): | with redis_client.lock(lock_name, timeout=600): | ||||
| keyword_table = self._get_dataset_keyword_table() | keyword_table = self._get_dataset_keyword_table() | ||||
| if keyword_table is not None: | if keyword_table is not None: | ||||
| return documents | return documents | ||||
| def delete(self) -> None: | def delete(self) -> None: | ||||
| lock_name = "keyword_indexing_lock_{}".format(self.dataset.id) | |||||
| lock_name = f"keyword_indexing_lock_{self.dataset.id}" | |||||
| with redis_client.lock(lock_name, timeout=600): | with redis_client.lock(lock_name, timeout=600): | ||||
| dataset_keyword_table = self.dataset.dataset_keyword_table | dataset_keyword_table = self.dataset.dataset_keyword_table | ||||
| if dataset_keyword_table: | if dataset_keyword_table: |
| def _create_table(self, dimension: int) -> None: | def _create_table(self, dimension: int) -> None: | ||||
| # Try to grab distributed lock and create table | # Try to grab distributed lock and create table | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=60): | with redis_client.lock(lock_name, timeout=60): | ||||
| table_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| table_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(table_exist_cache_key): | if redis_client.get(table_exist_cache_key): | ||||
| return | return | ||||
| self.add_texts(texts, embeddings, **kwargs) | self.add_texts(texts, embeddings, **kwargs) | ||||
| def create_collection(self, collection_name: str): | def create_collection(self, collection_name: str): | ||||
| lock_name = "vector_indexing_lock_{}".format(collection_name) | |||||
| lock_name = f"vector_indexing_lock_{collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| self._client.get_or_create_collection(collection_name) | self._client.get_or_create_collection(collection_name) |
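The `lock_name` and cache-key hunks in this and the following blocks only replace `.format()` calls with f-strings; the creation guard itself is unchanged. As context, a minimal sketch of that guard, assuming `redis_client` is a `redis.Redis` instance (the `_do_create_collection` helper is a hypothetical placeholder):

```python
import redis

redis_client = redis.Redis()  # the application wires this connection up elsewhere


def _do_create_collection(collection_name: str, dimension: int) -> None:
    """Hypothetical placeholder for the vector-store specific creation call."""
    ...


def create_collection(collection_name: str, dimension: int) -> None:
    lock_name = f"vector_indexing_lock_{collection_name}"
    collection_exist_cache_key = f"vector_indexing_{collection_name}"
    # Serialize concurrent creators with a distributed lock, then use a short-lived
    # cache flag so later callers can skip the creation work entirely.
    with redis_client.lock(lock_name, timeout=20):
        if redis_client.get(collection_exist_cache_key):
            return
        _do_create_collection(collection_name, dimension)
        redis_client.set(collection_exist_cache_key, 1, ex=3600)
```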
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def _create_collection(self, vector_length: int, uuid: str): | def _create_collection(self, vector_length: int, uuid: str): | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| if self._collection_exists(self._collection_name): | if self._collection_exists(self._collection_name): | ||||
| try: | try: | ||||
| self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() | self._cluster.query(query, named_parameters={"doc_ids": ids}).execute() | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.exception(f"Failed to delete documents, ids: {ids}") | |||||
| logger.exception("Failed to delete documents, ids: %s", ids) | |||||
| def delete_by_document_id(self, document_id: str): | def delete_by_document_id(self, document_id: str): | ||||
| query = f""" | query = f""" |
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | ||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| logger.info(f"Collection {self._collection_name} already exists.") | |||||
| logger.info("Collection %s already exists.", self._collection_name) | |||||
| return | return | ||||
| if not self._client.indices.exists(index=self._collection_name): | if not self._client.indices.exists(index=self._collection_name): |
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | ||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| logger.info(f"Collection {self._collection_name} already exists.") | |||||
| logger.info("Collection %s already exists.", self._collection_name) | |||||
| return | return | ||||
| if not self._client.indices.exists(index=self._collection_name): | if not self._client.indices.exists(index=self._collection_name): |
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | ||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| logger.info(f"Collection {self._collection_name} already exists.") | |||||
| logger.info("Collection %s already exists.", self._collection_name) | |||||
| return | return | ||||
| if not self._client.indices.exists(index=self._collection_name): | if not self._client.indices.exists(index=self._collection_name): |
| timeout: int = 60, | timeout: int = 60, | ||||
| **kwargs, | **kwargs, | ||||
| ): | ): | ||||
| logger.info(f"Total documents to add: {len(documents)}") | |||||
| logger.info("Total documents to add: %s", len(documents)) | |||||
| uuids = self._get_uuids(documents) | uuids = self._get_uuids(documents) | ||||
| total_docs = len(documents) | total_docs = len(documents) | ||||
| time.sleep(0.5) | time.sleep(0.5) | ||||
| except Exception: | except Exception: | ||||
| logger.exception(f"Failed to process batch {batch_num + 1}") | |||||
| logger.exception("Failed to process batch %s", batch_num + 1) | |||||
| raise | raise | ||||
| def get_ids_by_metadata_field(self, key: str, value: str): | def get_ids_by_metadata_field(self, key: str, value: str): | ||||
| # 1. First check if collection exists | # 1. First check if collection exists | ||||
| if not self._client.indices.exists(index=self._collection_name): | if not self._client.indices.exists(index=self._collection_name): | ||||
| logger.warning(f"Collection {self._collection_name} does not exist") | |||||
| logger.warning("Collection %s does not exist", self._collection_name) | |||||
| return | return | ||||
| # 2. Batch process deletions | # 2. Batch process deletions | ||||
| } | } | ||||
| ) | ) | ||||
| else: | else: | ||||
| logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") | |||||
| logger.warning("DELETE BY ID: ID %s does not exist in the index.", id) | |||||
| # 3. Perform bulk deletion if there are valid documents to delete | # 3. Perform bulk deletion if there are valid documents to delete | ||||
| if actions: | if actions: | ||||
| doc_id = delete_error.get("_id") | doc_id = delete_error.get("_id") | ||||
| if status == 404: | if status == 404: | ||||
| logger.warning(f"Document not found for deletion: {doc_id}") | |||||
| logger.warning("Document not found for deletion: %s", doc_id) | |||||
| else: | else: | ||||
| logger.exception(f"Error deleting document: {error}") | |||||
| logger.exception("Error deleting document: %s", error) | |||||
| def delete(self) -> None: | def delete(self) -> None: | ||||
| if self._using_ugc: | if self._using_ugc: | ||||
| self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) | self._client.indices.delete(index=self._collection_name, params={"timeout": 60}) | ||||
| logger.info("Delete index success") | logger.info("Delete index success") | ||||
| else: | else: | ||||
| logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.") | |||||
| logger.warning("Index '%s' does not exist. No deletion performed.", self._collection_name) | |||||
| def text_exists(self, id: str) -> bool: | def text_exists(self, id: str) -> bool: | ||||
| try: | try: | ||||
| params["routing"] = self._routing # type: ignore | params["routing"] = self._routing # type: ignore | ||||
| response = self._client.search(index=self._collection_name, body=query, params=params) | response = self._client.search(index=self._collection_name, body=query, params=params) | ||||
| except Exception: | except Exception: | ||||
| logger.exception(f"Error executing vector search, query: {query}") | |||||
| logger.exception("Error executing vector search, query: %s", query) | |||||
| raise | raise | ||||
| docs_and_scores = [] | docs_and_scores = [] | ||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | ||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| logger.info(f"Collection {self._collection_name} already exists.") | |||||
| logger.info("Collection %s already exists.", self._collection_name) | |||||
| return | return | ||||
| if self._client.indices.exists(index=self._collection_name): | if self._client.indices.exists(index=self._collection_name): | ||||
| logger.info(f"{self._collection_name.lower()} already exists.") | |||||
| logger.info("%s already exists.", self._collection_name.lower()) | |||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | redis_client.set(collection_exist_cache_key, 1, ex=3600) | ||||
| return | return | ||||
| if len(self.kwargs) == 0 and len(kwargs) != 0: | if len(self.kwargs) == 0 and len(kwargs) != 0: |
| # For standard Milvus installations, check version number | # For standard Milvus installations, check version number | ||||
| return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version | return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") | |||||
| logger.warning("Failed to check Milvus version: %s. Disabling hybrid search.", str(e)) | |||||
| return False | return False | ||||
| def get_type(self) -> str: | def get_type(self) -> str: | ||||
| """ | """ | ||||
| Create a new collection in Milvus with the specified schema and index parameters. | Create a new collection in Milvus with the specified schema and index parameters. | ||||
| """ | """ | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| # Grab the existing collection if it exists | # Grab the existing collection if it exists |
| return self.add_texts(documents=texts, embeddings=embeddings, **kwargs) | return self.add_texts(documents=texts, embeddings=embeddings, **kwargs) | ||||
| def _create_collection(self, dimension: int): | def _create_collection(self, dimension: int): | ||||
| logging.info(f"create MyScale collection {self._collection_name} with dimension {dimension}") | |||||
| logging.info("create MyScale collection %s with dimension %s", self._collection_name, dimension) | |||||
| self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}") | self._client.command(f"CREATE DATABASE IF NOT EXISTS {self._config.database}") | ||||
| fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else "" | fts_params = f"('{self._config.fts_params}')" if self._config.fts_params else "" | ||||
| sql = f""" | sql = f""" | ||||
| for r in self._client.query(sql).named_results() | for r in self._client.query(sql).named_results() | ||||
| ] | ] | ||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") # noqa:TRY401 | |||||
| logging.exception("\033[91m\033[1m%s\033[0m \033[95m%s\033[0m", type(e), str(e)) # noqa:TRY401 | |||||
| return [] | return [] | ||||
| def delete(self) -> None: | def delete(self) -> None: |
| logger.debug("Current OceanBase version is %s", ob_version) | logger.debug("Current OceanBase version is %s", ob_version) | ||||
| return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version | return version.parse(ob_version).base_version >= version.parse("4.3.5.1").base_version | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.warning(f"Failed to check OceanBase version: {str(e)}. Disabling hybrid search.") | |||||
| logger.warning("Failed to check OceanBase version: %s. Disabling hybrid search.", str(e)) | |||||
| return False | return False | ||||
| def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): | ||||
| return docs | return docs | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.warning(f"Failed to fulltext search: {str(e)}.") | |||||
| logger.warning("Failed to fulltext search: %s.", str(e)) | |||||
| return [] | return [] | ||||
| def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: | def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: |
| def delete_by_ids(self, ids: list[str]) -> None: | def delete_by_ids(self, ids: list[str]) -> None: | ||||
| index_name = self._collection_name.lower() | index_name = self._collection_name.lower() | ||||
| if not self._client.indices.exists(index=index_name): | if not self._client.indices.exists(index=index_name): | ||||
| logger.warning(f"Index {index_name} does not exist") | |||||
| logger.warning("Index %s does not exist", index_name) | |||||
| return | return | ||||
| # Obtaining All Actual Documents_ID | # Obtaining All Actual Documents_ID | ||||
| if es_ids: | if es_ids: | ||||
| actual_ids.extend(es_ids) | actual_ids.extend(es_ids) | ||||
| else: | else: | ||||
| logger.warning(f"Document with metadata doc_id {doc_id} not found for deletion") | |||||
| logger.warning("Document with metadata doc_id %s not found for deletion", doc_id) | |||||
| if actual_ids: | if actual_ids: | ||||
| actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids] | actions = [{"_op_type": "delete", "_index": index_name, "_id": es_id} for es_id in actual_ids] | ||||
| doc_id = delete_error.get("_id") | doc_id = delete_error.get("_id") | ||||
| if status == 404: | if status == 404: | ||||
| logger.warning(f"Document not found for deletion: {doc_id}") | |||||
| logger.warning("Document not found for deletion: %s", doc_id) | |||||
| else: | else: | ||||
| logger.exception(f"Error deleting document: {error}") | |||||
| logger.exception("Error deleting document: %s", error) | |||||
| def delete(self) -> None: | def delete(self) -> None: | ||||
| self._client.indices.delete(index=self._collection_name.lower()) | self._client.indices.delete(index=self._collection_name.lower()) | ||||
| try: | try: | ||||
| response = self._client.search(index=self._collection_name.lower(), body=query) | response = self._client.search(index=self._collection_name.lower(), body=query) | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.exception(f"Error executing vector search, query: {query}") | |||||
| logger.exception("Error executing vector search, query: %s", query) | |||||
| raise | raise | ||||
| docs = [] | docs = [] | ||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" | collection_exist_cache_key = f"vector_indexing_{self._collection_name.lower()}" | ||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| logger.info(f"Collection {self._collection_name.lower()} already exists.") | |||||
| logger.info("Collection %s already exists.", self._collection_name.lower()) | |||||
| return | return | ||||
| if not self._client.indices.exists(index=self._collection_name.lower()): | if not self._client.indices.exists(index=self._collection_name.lower()): | ||||
| }, | }, | ||||
| } | } | ||||
| logger.info(f"Creating OpenSearch index {self._collection_name.lower()}") | |||||
| logger.info("Creating OpenSearch index %s", self._collection_name.lower()) | |||||
| self._client.indices.create(index=self._collection_name.lower(), body=index_body) | self._client.indices.create(index=self._collection_name.lower(), body=index_body) | ||||
| redis_client.set(collection_exist_cache_key, 1, ex=3600) | redis_client.set(collection_exist_cache_key, 1, ex=3600) |
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def create_collection(self, dimension: int): | def create_collection(self, dimension: int): | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| index_name = f"{self._collection_name}_embedding_index" | index_name = f"{self._collection_name}_embedding_index" |
| cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) | ||||
| except psycopg2.errors.UndefinedTable: | except psycopg2.errors.UndefinedTable: | ||||
| # table not exists | # table not exists | ||||
| logging.warning(f"Table {self.table_name} not found, skipping delete operation.") | |||||
| logging.warning("Table %s not found, skipping delete operation.", self.table_name) | |||||
| return | return | ||||
| except Exception as e: | except Exception as e: | ||||
| raise e | raise e |
| self.add_texts(texts, embeddings, **kwargs) | self.add_texts(texts, embeddings, **kwargs) | ||||
| def create_collection(self, collection_name: str, vector_size: int): | def create_collection(self, collection_name: str, vector_size: int): | ||||
| lock_name = "vector_indexing_lock_{}".format(collection_name) | |||||
| lock_name = f"vector_indexing_lock_{collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| collection_name = collection_name or uuid.uuid4().hex | collection_name = collection_name or uuid.uuid4().hex |
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def create_collection(self, dimension: int): | def create_collection(self, dimension: int): | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| index_name = f"{self._collection_name}_embedding_index" | index_name = f"{self._collection_name}_embedding_index" |
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | ||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| logging.info(f"Collection {self._collection_name} already exists.") | |||||
| logging.info("Collection %s already exists.", self._collection_name) | |||||
| return | return | ||||
| self._create_table_if_not_exist() | self._create_table_if_not_exist() |
| def _create_collection(self, dimension: int) -> None: | def _create_collection(self, dimension: int) -> None: | ||||
| self._dimension = dimension | self._dimension = dimension | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| self.add_texts(texts, embeddings, **kwargs) | self.add_texts(texts, embeddings, **kwargs) | ||||
| def create_collection(self, collection_name: str, vector_size: int): | def create_collection(self, collection_name: str, vector_size: int): | ||||
| lock_name = "vector_indexing_lock_{}".format(collection_name) | |||||
| lock_name = f"vector_indexing_lock_{collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| collection_name = collection_name or uuid.uuid4().hex | collection_name = collection_name or uuid.uuid4().hex |
| def _create_collection(self, dimension: int): | def _create_collection(self, dimension: int): | ||||
| logger.info("_create_collection, collection_name " + self._collection_name) | logger.info("_create_collection, collection_name " + self._collection_name) | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| tidb_dist_func = self._get_distance_func() | tidb_dist_func = self._get_distance_func() | ||||
| query_vector_str = ", ".join(format(x) for x in query_vector) | query_vector_str = ", ".join(format(x) for x in query_vector) | ||||
| query_vector_str = "[" + query_vector_str + "]" | query_vector_str = "[" + query_vector_str + "]" | ||||
| logger.debug( | logger.debug( | ||||
| f"_collection_name: {self._collection_name}, score_threshold: {score_threshold}, distance: {distance}" | |||||
| "_collection_name: %s, score_threshold: %s, distance: %s", self._collection_name, score_threshold, distance | |||||
| ) | ) | ||||
| docs = [] | docs = [] |
| def create(self, texts: Optional[list] = None, **kwargs): | def create(self, texts: Optional[list] = None, **kwargs): | ||||
| if texts: | if texts: | ||||
| start = time.time() | start = time.time() | ||||
| logger.info(f"start embedding {len(texts)} texts {start}") | |||||
| logger.info("start embedding %s texts %s", len(texts), start) | |||||
| batch_size = 1000 | batch_size = 1000 | ||||
| total_batches = (len(texts) + batch_size - 1) // batch_size | total_batches = (len(texts) + batch_size - 1) // batch_size | ||||
| for i in range(0, len(texts), batch_size): | for i in range(0, len(texts), batch_size): | ||||
| batch = texts[i : i + batch_size] | batch = texts[i : i + batch_size] | ||||
| batch_start = time.time() | batch_start = time.time() | ||||
| logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)") | |||||
| logger.info("Processing batch %s/%s (%s texts)", i // batch_size + 1, total_batches, len(batch)) | |||||
| batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch]) | batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch]) | ||||
| logger.info( | logger.info( | ||||
| f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s" | |||||
| "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start | |||||
| ) | ) | ||||
| self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) | self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) | ||||
| logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s") | |||||
| logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) | |||||
| def add_texts(self, documents: list[Document], **kwargs): | def add_texts(self, documents: list[Document], **kwargs): | ||||
| if kwargs.get("duplicate_check", False): | if kwargs.get("duplicate_check", False): | ||||
| self._vector_processor.delete() | self._vector_processor.delete() | ||||
| # delete collection redis cache | # delete collection redis cache | ||||
| if self._vector_processor.collection_name: | if self._vector_processor.collection_name: | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._vector_processor.collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._vector_processor.collection_name}" | |||||
| redis_client.delete(collection_exist_cache_key) | redis_client.delete(collection_exist_cache_key) | ||||
| def _get_embeddings(self) -> Embeddings: | def _get_embeddings(self) -> Embeddings: |
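One thing to watch when converting timing logs like the ones above: an f-string spec such as `{elapsed:.3f}s` carries a precision, and the lazy form keeps it by using `%.3f` rather than a bare `%s`. A small sketch (the `work()` function is illustrative):

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def work() -> None:
    """Illustrative placeholder for the batch being timed."""
    time.sleep(0.01)


start = time.time()
work()
# %.3f preserves the three-decimal formatting the original f-string used.
logger.info("Embedding batch took %.3fs", time.time() - start)
```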
| self.add_texts(texts, embeddings) | self.add_texts(texts, embeddings) | ||||
| def _create_collection(self): | def _create_collection(self): | ||||
| lock_name = "vector_indexing_lock_{}".format(self._collection_name) | |||||
| lock_name = f"vector_indexing_lock_{self._collection_name}" | |||||
| with redis_client.lock(lock_name, timeout=20): | with redis_client.lock(lock_name, timeout=20): | ||||
| collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name) | |||||
| collection_exist_cache_key = f"vector_indexing_{self._collection_name}" | |||||
| if redis_client.get(collection_exist_cache_key): | if redis_client.get(collection_exist_cache_key): | ||||
| return | return | ||||
| schema = self._default_schema(self._collection_name) | schema = self._default_schema(self._collection_name) |
| # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan | # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan | ||||
| if np.isnan(normalized_embedding).any(): | if np.isnan(normalized_embedding).any(): | ||||
| # for issue #11827 float values are not json compliant | # for issue #11827 float values are not json compliant | ||||
| logger.warning(f"Normalized embedding is nan: {normalized_embedding}") | |||||
| logger.warning("Normalized embedding is nan: %s", normalized_embedding) | |||||
| continue | continue | ||||
| embedding_queue_embeddings.append(normalized_embedding) | embedding_queue_embeddings.append(normalized_embedding) | ||||
| except IntegrityError: | except IntegrityError: | ||||
| raise ValueError("Normalized embedding is nan please try again") | raise ValueError("Normalized embedding is nan please try again") | ||||
| except Exception as ex: | except Exception as ex: | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'") | |||||
| logging.exception("Failed to embed query text '%s...(%s chars)'", text[:10], len(text)) | |||||
| raise ex | raise ex | ||||
| try: | try: | ||||
| redis_client.setex(embedding_cache_key, 600, encoded_str) | redis_client.setex(embedding_cache_key, 600, encoded_str) | ||||
| except Exception as ex: | except Exception as ex: | ||||
| if dify_config.DEBUG: | if dify_config.DEBUG: | ||||
| logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'") | |||||
| logging.exception( | |||||
| "Failed to add embedding to redis for the text '%s...(%s chars)'", text[:10], len(text) | |||||
| ) | |||||
| raise ex | raise ex | ||||
| return embedding_results # type: ignore | return embedding_results # type: ignore |
| if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: | if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size: | ||||
| if total > self._chunk_size: | if total > self._chunk_size: | ||||
| logger.warning( | logger.warning( | ||||
| f"Created a chunk of size {total}, which is longer than the specified {self._chunk_size}" | |||||
| "Created a chunk of size %s, which is longer than the specified %s", total, self._chunk_size | |||||
| ) | ) | ||||
| if len(current_doc) > 0: | if len(current_doc) > 0: | ||||
| doc = self._join_docs(current_doc, separator) | doc = self._join_docs(current_doc, separator) |
| RepositoryImportError: If the configured repository cannot be created | RepositoryImportError: If the configured repository cannot be created | ||||
| """ | """ | ||||
| class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY | class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY | ||||
| logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") | |||||
| logger.debug("Creating WorkflowExecutionRepository from: %s", class_path) | |||||
| try: | try: | ||||
| repository_class = cls._import_class(class_path) | repository_class = cls._import_class(class_path) | ||||
| RepositoryImportError: If the configured repository cannot be created | RepositoryImportError: If the configured repository cannot be created | ||||
| """ | """ | ||||
| class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY | class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY | ||||
| logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") | |||||
| logger.debug("Creating WorkflowNodeExecutionRepository from: %s", class_path) | |||||
| try: | try: | ||||
| repository_class = cls._import_class(class_path) | repository_class = cls._import_class(class_path) |
| session.commit() | session.commit() | ||||
| # Update the in-memory cache for faster subsequent lookups | # Update the in-memory cache for faster subsequent lookups | ||||
| logger.debug(f"Updating cache for execution_id: {db_model.id}") | |||||
| logger.debug("Updating cache for execution_id: %s", db_model.id) | |||||
| self._execution_cache[db_model.id] = db_model | self._execution_cache[db_model.id] = db_model |
| # Update the in-memory cache for faster subsequent lookups | # Update the in-memory cache for faster subsequent lookups | ||||
| # Only cache if we have a node_execution_id to use as the cache key | # Only cache if we have a node_execution_id to use as the cache key | ||||
| if db_model.node_execution_id: | if db_model.node_execution_id: | ||||
| logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") | |||||
| logger.debug("Updating cache for node_execution_id: %s", db_model.node_execution_id) | |||||
| self._node_execution_cache[db_model.node_execution_id] = db_model | self._node_execution_cache[db_model.node_execution_id] = db_model | ||||
| def get_db_models_by_workflow_run( | def get_db_models_by_workflow_run( |
| ) | ) | ||||
| except Exception as e: | except Exception as e: | ||||
| builtin_provider = None | builtin_provider = None | ||||
| logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True) | |||||
| logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True) | |||||
| # if the provider has been deleted, raise an error | # if the provider has been deleted, raise an error | ||||
| if builtin_provider is None: | if builtin_provider is None: | ||||
| raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") | raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") | ||||
| yield provider | yield provider | ||||
| except Exception: | except Exception: | ||||
| logger.exception(f"load builtin provider {provider_path}") | |||||
| logger.exception("load builtin provider %s", provider_path) | |||||
| continue | continue | ||||
| # set builtin providers loaded | # set builtin providers loaded | ||||
| cls._builtin_providers_loaded = True | cls._builtin_providers_loaded = True |
| main_content_type = mimetypes.guess_type(filename)[0] | main_content_type = mimetypes.guess_type(filename)[0] | ||||
| if main_content_type not in supported_content_types: | if main_content_type not in supported_content_types: | ||||
| return "Unsupported content-type [{}] of URL.".format(main_content_type) | |||||
| return f"Unsupported content-type [{main_content_type}] of URL." | |||||
| if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: | ||||
| return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) | return cast(str, ExtractProcessor.load_from_url(url, return_text=True)) | ||||
| response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore | response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore | ||||
| if response.status_code != 200: | if response.status_code != 200: | ||||
| return "URL returned status code {}.".format(response.status_code) | |||||
| return f"URL returned status code {response.status_code}." | |||||
| # Detect encoding using chardet | # Detect encoding using chardet | ||||
| detected_encoding = chardet.detect(response.content) | detected_encoding = chardet.detect(response.content) |
| files.append(file_dict) | files.append(file_dict) | ||||
| except Exception: | except Exception: | ||||
| logger.exception(f"Failed to transform file {file}") | |||||
| logger.exception("Failed to transform file %s", file) | |||||
| else: | else: | ||||
| parameters_result[parameter.name] = tool_parameters.get(parameter.name) | parameters_result[parameter.name] = tool_parameters.get(parameter.name) | ||||
| while True: | while True: | ||||
| # max steps reached | # max steps reached | ||||
| if self.graph_runtime_state.node_run_steps > self.max_execution_steps: | if self.graph_runtime_state.node_run_steps > self.max_execution_steps: | ||||
| raise GraphRunFailedError("Max steps {} reached.".format(self.max_execution_steps)) | |||||
| raise GraphRunFailedError(f"Max steps {self.max_execution_steps} reached.") | |||||
| # or max execution time reached | # or max execution time reached | ||||
| if self._is_timed_out( | if self._is_timed_out( | ||||
| start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time | start_at=self.graph_runtime_state.start_at, max_execution_time=self.max_execution_time | ||||
| ): | ): | ||||
| raise GraphRunFailedError("Max execution time {}s reached.".format(self.max_execution_time)) | |||||
| raise GraphRunFailedError(f"Max execution time {self.max_execution_time}s reached.") | |||||
| # init route node state | # init route node state | ||||
| route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) | route_node_state = self.graph_runtime_state.node_run_state.create_node_state(node_id=next_node_id) | ||||
| edge = cast(GraphEdge, sub_edge_mappings[0]) | edge = cast(GraphEdge, sub_edge_mappings[0]) | ||||
| if edge.run_condition is None: | if edge.run_condition is None: | ||||
| logger.warning(f"Edge {edge.target_node_id} run condition is None") | |||||
| logger.warning("Edge %s run condition is None", edge.target_node_id) | |||||
| continue | continue | ||||
| result = ConditionManager.get_condition_handler( | result = ConditionManager.get_condition_handler( | ||||
| ) | ) | ||||
| return | return | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.exception(f"Node {node.title} run failed") | |||||
| logger.exception("Node %s run failed", node.title) | |||||
| raise e | raise e | ||||
| def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): | def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): |
| reachable_node_ids: list[str] = [] | reachable_node_ids: list[str] = [] | ||||
| unreachable_first_node_ids: list[str] = [] | unreachable_first_node_ids: list[str] = [] | ||||
| if finished_node_id not in self.graph.edge_mapping: | if finished_node_id not in self.graph.edge_mapping: | ||||
| logger.warning(f"node {finished_node_id} has no edge mapping") | |||||
| logger.warning("node %s has no edge mapping", finished_node_id) | |||||
| return | return | ||||
| for edge in self.graph.edge_mapping[finished_node_id]: | for edge in self.graph.edge_mapping[finished_node_id]: | ||||
| if ( | if ( |
| try: | try: | ||||
| result = self._run() | result = self._run() | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.exception(f"Node {self.node_id} failed to run") | |||||
| logger.exception("Node %s failed to run", self.node_id) | |||||
| result = NodeRunResult( | result = NodeRunResult( | ||||
| status=WorkflowNodeExecutionStatus.FAILED, | status=WorkflowNodeExecutionStatus.FAILED, | ||||
| error=str(e), | error=str(e), |
| text.append(markdown_table) | text.append(markdown_table) | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.warning(f"Failed to extract table from DOC: {e}") | |||||
| logger.warning("Failed to extract table from DOC: %s", e) | |||||
| continue | continue | ||||
| return "\n".join(text) | return "\n".join(text) |
| }, | }, | ||||
| ) | ) | ||||
| except HttpRequestNodeError as e: | except HttpRequestNodeError as e: | ||||
| logger.warning(f"http request node {self.node_id} failed to run: {e}") | |||||
| logger.warning("http request node %s failed to run: %s", self.node_id, e) | |||||
| return NodeRunResult( | return NodeRunResult( | ||||
| status=WorkflowNodeExecutionStatus.FAILED, | status=WorkflowNodeExecutionStatus.FAILED, | ||||
| error=str(e), | error=str(e), |
| var_mapping: dict[str, list[str]] = {} | var_mapping: dict[str, list[str]] = {} | ||||
| for case in typed_node_data.cases or []: | for case in typed_node_data.cases or []: | ||||
| for condition in case.conditions: | for condition in case.conditions: | ||||
| key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) | |||||
| key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" | |||||
| var_mapping[key] = condition.variable_selector | var_mapping[key] = condition.variable_selector | ||||
| return var_mapping | return var_mapping |
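The hunk above swaps a `str.format()` call for an f-string when building the condition-variable mapping key; the resulting key is identical. A quick equivalence check with made-up values (`node_id` and `variable_selector` here are illustrative only):

```python
node_id = "if_else_node"              # illustrative value
variable_selector = ["sys", "query"]  # illustrative value

old_key = "{}.#{}#".format(node_id, ".".join(variable_selector))
new_key = f"{node_id}.#{'.'.join(variable_selector)}#"

assert old_key == new_key == "if_else_node.#sys.query#"
```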
| ) | ) | ||||
| except IterationNodeError as e: | except IterationNodeError as e: | ||||
| logger.warning(f"Iteration run failed:{str(e)}") | |||||
| logger.warning("Iteration run failed:%s", str(e)) | |||||
| yield IterationRunFailedEvent( | yield IterationRunFailedEvent( | ||||
| iteration_id=self.id, | iteration_id=self.id, | ||||
| iteration_node_id=self.node_id, | iteration_node_id=self.node_id, |
| return cast(dict, json.loads(json_str)) | return cast(dict, json.loads(json_str)) | ||||
| except Exception: | except Exception: | ||||
| pass | pass | ||||
| logger.info(f"extra error: {result}") | |||||
| logger.info("extra error: %s", result) | |||||
| return None | return None | ||||
| def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: | def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]: | ||||
| return cast(dict, json.loads(json_str)) | return cast(dict, json.loads(json_str)) | ||||
| except Exception: | except Exception: | ||||
| pass | pass | ||||
| logger.info(f"extra error: {result}") | |||||
| logger.info("extra error: %s", result) | |||||
| return None | return None | ||||
| def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: | def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: |
| # check call depth | # check call depth | ||||
| workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH | workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH | ||||
| if call_depth > workflow_call_max_depth: | if call_depth > workflow_call_max_depth: | ||||
| raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) | |||||
| raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.") | |||||
| # init workflow run state | # init workflow run state | ||||
| graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) | ||||
| # run node | # run node | ||||
| generator = node.run() | generator = node.run() | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") | |||||
| logger.exception( | |||||
| "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s", | |||||
| workflow.id, | |||||
| node.id, | |||||
| node.type_, | |||||
| node.version(), | |||||
| ) | |||||
| raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | ||||
| return node, generator | return node, generator | ||||
| return node, generator | return node, generator | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") | |||||
| logger.exception( | |||||
| "error while running node, node_id=%s, node_type=%s, node_version=%s", | |||||
| node.id, | |||||
| node.type_, | |||||
| node.version(), | |||||
| ) | |||||
| raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) | ||||
| @staticmethod | @staticmethod |
| documents = [] | documents = [] | ||||
| start_at = time.perf_counter() | start_at = time.perf_counter() | ||||
| for document_id in document_ids: | for document_id in document_ids: | ||||
| logging.info(click.style("Start process document: {}".format(document_id), fg="green")) | |||||
| logging.info(click.style(f"Start process document: {document_id}", fg="green")) | |||||
| document = ( | document = ( | ||||
| db.session.query(Document) | db.session.query(Document) | ||||
| indexing_runner = IndexingRunner() | indexing_runner = IndexingRunner() | ||||
| indexing_runner.run(documents) | indexing_runner.run(documents) | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) | |||||
| logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) | |||||
| except DocumentIsPausedError as ex: | except DocumentIsPausedError as ex: | ||||
| logging.info(click.style(str(ex), fg="yellow")) | logging.info(click.style(str(ex), fg="yellow")) | ||||
| except Exception: | except Exception: |
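The f-strings kept inside `click.style(...)` above are intentionally left alone: G004 only flags f-strings passed directly as the logging message, and here the message is the already-formatted, color-wrapped string. A small sketch with a made-up dataset id (requires click):

    import logging
    import time

    import click

    logging.basicConfig(level=logging.INFO)

    dataset_id = "ds-123"  # hypothetical id
    start_at = time.perf_counter()
    # ... indexing work would run here ...
    end_at = time.perf_counter()

    # click.style only wraps the text in ANSI color codes; the f-string is an
    # argument to click.style, not the logging template, so G004 does not apply.
    logging.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))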
| duration = time_module.perf_counter() - start_time | duration = time_module.perf_counter() - start_time | ||||
| logger.info( | logger.info( | ||||
| f"Provider updates completed successfully. " | |||||
| f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, " | |||||
| f"Tenant: {tenant_id}, Provider: {provider_name}" | |||||
| "Provider updates completed successfully. Updates: %s, Duration: %s s, Tenant: %s, Provider: %s", | |||||
| len(updates_to_perform), | |||||
| duration, | |||||
| tenant_id, | |||||
| provider_name, | |||||
| ) | ) | ||||
| except Exception as e: | except Exception as e: | ||||
| duration = time_module.perf_counter() - start_time | duration = time_module.perf_counter() - start_time | ||||
| logger.exception( | logger.exception( | ||||
| f"Provider updates failed after {duration:.3f}s. " | |||||
| f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, " | |||||
| f"Provider: {provider_name}" | |||||
| "Provider updates failed after %s s. Updates: %s, Tenant: %s, Provider: %s", | |||||
| duration, | |||||
| len(updates_to_perform), | |||||
| tenant_id, | |||||
| provider_name, | |||||
| ) | ) | ||||
| raise | raise | ||||
| rows_affected = result.rowcount | rows_affected = result.rowcount | ||||
| logger.debug( | logger.debug( | ||||
| f"Provider update ({description}): {rows_affected} rows affected. " | |||||
| f"Filters: {filters.model_dump()}, Values: {update_values}" | |||||
| "Provider update (%s): %s rows affected. Filters: %s, Values: %s", | |||||
| description, | |||||
| rows_affected, | |||||
| filters.model_dump(), | |||||
| update_values, | |||||
| ) | ) | ||||
| # If no rows were affected for quota updates, log a warning | # If no rows were affected for quota updates, log a warning | ||||
| if rows_affected == 0 and description == "quota_deduction_update": | if rows_affected == 0 and description == "quota_deduction_update": | ||||
| logger.warning( | logger.warning( | ||||
| f"No Provider rows updated for quota deduction. " | |||||
| f"This may indicate quota limit exceeded or provider not found. " | |||||
| f"Filters: {filters.model_dump()}" | |||||
| "No Provider rows updated for quota deduction. " | |||||
| "This may indicate quota limit exceeded or provider not found. " | |||||
| "Filters: %s", | |||||
| filters.model_dump(), | |||||
| ) | ) | ||||
| logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates") | |||||
| logger.debug("Successfully processed %s Provider updates", len(updates_to_perform)) |
| sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or "" | sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or "" | ||||
| ) | ) | ||||
| case _: | case _: | ||||
| raise ValueError("Unsupported mail type {}".format(mail_type)) | |||||
| raise ValueError(f"Unsupported mail type {mail_type}") | |||||
| def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): | def send(self, to: str, subject: str, html: str, from_: Optional[str] = None): | ||||
| if not self._client: | if not self._client: |
| try: | try: | ||||
| return func(*args, **kwargs) | return func(*args, **kwargs) | ||||
| except RedisError as e: | except RedisError as e: | ||||
| logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True) | |||||
| logger.warning("Redis operation failed in %s: %s", func.__name__, str(e), exc_info=True) | |||||
| return default_return | return default_return | ||||
| return wrapper | return wrapper |
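A self-contained sketch of the fallback decorator shown above (the decorator name and default handling are illustrative; only the RedisError/warning/exc_info pattern is taken from the diff):

    import functools
    import logging

    from redis.exceptions import RedisError

    logger = logging.getLogger(__name__)

    def redis_fallback(default_return=None):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                try:
                    return func(*args, **kwargs)
                except RedisError as e:
                    # exc_info=True attaches the traceback even at WARNING level.
                    logger.warning("Redis operation failed in %s: %s", func.__name__, e, exc_info=True)
                    return default_return
            return wrapper
        return decorator

    @redis_fallback(default_return=0)
    def get_counter(redis_client, key: str) -> int:
        return int(redis_client.get(key) or 0)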
| if self.account_key == "managedidentity": | if self.account_key == "managedidentity": | ||||
| return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore | return BlobServiceClient(account_url=self.account_url, credential=self.credential) # type: ignore | ||||
| cache_key = "azure_blob_sas_token_{}_{}".format(self.account_name, self.account_key) | |||||
| cache_key = f"azure_blob_sas_token_{self.account_name}_{self.account_key}" | |||||
| cache_result = redis_client.get(cache_key) | cache_result = redis_client.get(cache_key) | ||||
| if cache_result is not None: | if cache_result is not None: | ||||
| sas_token = cache_result.decode("utf-8") | sas_token = cache_result.decode("utf-8") |
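A hedged sketch of the cache lookup around the f-string key above; the token generation below is a stand-in for the real Azure SAS helper, and the TTL is an arbitrary example value:

    import hashlib

    def _generate_sas_token(account_name: str, account_key: str) -> str:
        # Stand-in for the real Azure SAS generation, for illustration only.
        return hashlib.sha256(f"{account_name}:{account_key}".encode()).hexdigest()

    def get_sas_token(redis_client, account_name: str, account_key: str) -> str:
        cache_key = f"azure_blob_sas_token_{account_name}_{account_key}"
        cached = redis_client.get(cache_key)
        if cached is not None:
            # redis-py returns bytes unless decode_responses=True is set on the client.
            return cached.decode("utf-8") if isinstance(cached, bytes) else cached
        token = _generate_sas_token(account_name, account_key)
        redis_client.setex(cache_key, 3000, token)  # TTL in seconds (example value)
        return token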
| Path(root).mkdir(parents=True, exist_ok=True) | Path(root).mkdir(parents=True, exist_ok=True) | ||||
| self.op = opendal.Operator(scheme=scheme, **kwargs) # type: ignore | self.op = opendal.Operator(scheme=scheme, **kwargs) # type: ignore | ||||
| logger.debug(f"opendal operator created with scheme {scheme}") | |||||
| logger.debug("opendal operator created with scheme %s", scheme) | |||||
| retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) | retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) | ||||
| self.op = self.op.layer(retry_layer) | self.op = self.op.layer(retry_layer) | ||||
| logger.debug("added retry layer to opendal operator") | logger.debug("added retry layer to opendal operator") | ||||
| def save(self, filename: str, data: bytes) -> None: | def save(self, filename: str, data: bytes) -> None: | ||||
| self.op.write(path=filename, bs=data) | self.op.write(path=filename, bs=data) | ||||
| logger.debug(f"file {filename} saved") | |||||
| logger.debug("file %s saved", filename) | |||||
| def load_once(self, filename: str) -> bytes: | def load_once(self, filename: str) -> bytes: | ||||
| if not self.exists(filename): | if not self.exists(filename): | ||||
| raise FileNotFoundError("File not found") | raise FileNotFoundError("File not found") | ||||
| content: bytes = self.op.read(path=filename) | content: bytes = self.op.read(path=filename) | ||||
| logger.debug(f"file {filename} loaded") | |||||
| logger.debug("file %s loaded", filename) | |||||
| return content | return content | ||||
| def load_stream(self, filename: str) -> Generator: | def load_stream(self, filename: str) -> Generator: | ||||
| file = self.op.open(path=filename, mode="rb") | file = self.op.open(path=filename, mode="rb") | ||||
| while chunk := file.read(batch_size): | while chunk := file.read(batch_size): | ||||
| yield chunk | yield chunk | ||||
| logger.debug(f"file {filename} loaded as stream") | |||||
| logger.debug("file %s loaded as stream", filename) | |||||
| def download(self, filename: str, target_filepath: str): | def download(self, filename: str, target_filepath: str): | ||||
| if not self.exists(filename): | if not self.exists(filename): | ||||
| with Path(target_filepath).open("wb") as f: | with Path(target_filepath).open("wb") as f: | ||||
| f.write(self.op.read(path=filename)) | f.write(self.op.read(path=filename)) | ||||
| logger.debug(f"file {filename} downloaded to {target_filepath}") | |||||
| logger.debug("file %s downloaded to %s", filename, target_filepath) | |||||
| def exists(self, filename: str) -> bool: | def exists(self, filename: str) -> bool: | ||||
| res: bool = self.op.exists(path=filename) | res: bool = self.op.exists(path=filename) | ||||
| def delete(self, filename: str): | def delete(self, filename: str): | ||||
| if self.exists(filename): | if self.exists(filename): | ||||
| self.op.delete(path=filename) | self.op.delete(path=filename) | ||||
| logger.debug(f"file {filename} deleted") | |||||
| logger.debug("file %s deleted", filename) | |||||
| return | return | ||||
| logger.debug(f"file {filename} not found, skip delete") | |||||
| logger.debug("file %s not found, skip delete", filename) | |||||
| def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: | def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: | ||||
| if not self.exists(path): | if not self.exists(path): | ||||
| all_files = self.op.scan(path=path) | all_files = self.op.scan(path=path) | ||||
| if files and directories: | if files and directories: | ||||
| logger.debug(f"files and directories on {path} scanned") | |||||
| logger.debug("files and directories on %s scanned", path) | |||||
| return [f.path for f in all_files] | return [f.path for f in all_files] | ||||
| if files: | if files: | ||||
| logger.debug(f"files on {path} scanned") | |||||
| logger.debug("files on %s scanned", path) | |||||
| return [f.path for f in all_files if not f.path.endswith("/")] | return [f.path for f in all_files if not f.path.endswith("/")] | ||||
| elif directories: | elif directories: | ||||
| logger.debug(f"directories on {path} scanned") | |||||
| logger.debug("directories on %s scanned", path) | |||||
| return [f.path for f in all_files if f.path.endswith("/")] | return [f.path for f in all_files if f.path.endswith("/")] | ||||
| else: | else: | ||||
| raise ValueError("At least one of files or directories must be True") | raise ValueError("At least one of files or directories must be True") |
| def load_once(self, filename: str) -> bytes: | def load_once(self, filename: str) -> bytes: | ||||
| data = self.client.get_object(bucket=self.bucket_name, key=filename).read() | data = self.client.get_object(bucket=self.bucket_name, key=filename).read() | ||||
| if not isinstance(data, bytes): | if not isinstance(data, bytes): | ||||
| raise TypeError("Expected bytes, got {}".format(type(data).__name__)) | |||||
| raise TypeError(f"Expected bytes, got {type(data).__name__}") | |||||
| return data | return data | ||||
| def load_stream(self, filename: str) -> Generator: | def load_stream(self, filename: str) -> Generator: |
| if re.match(pattern, email) is not None: | if re.match(pattern, email) is not None: | ||||
| return email | return email | ||||
| error = "{email} is not a valid email.".format(email=email) | |||||
| error = f"{email} is not a valid email." | |||||
| raise ValueError(error) | raise ValueError(error) | ||||
| uuid_obj = uuid.UUID(value) | uuid_obj = uuid.UUID(value) | ||||
| return str(uuid_obj) | return str(uuid_obj) | ||||
| except ValueError: | except ValueError: | ||||
| error = "{value} is not a valid uuid.".format(value=value) | |||||
| error = f"{value} is not a valid uuid." | |||||
| raise ValueError(error) | raise ValueError(error) | ||||
| raise ValueError | raise ValueError | ||||
| return int_timestamp | return int_timestamp | ||||
| except ValueError: | except ValueError: | ||||
| error = "{timestamp} is not a valid timestamp.".format(timestamp=timestamp) | |||||
| error = f"{timestamp} is not a valid timestamp." | |||||
| raise ValueError(error) | raise ValueError(error) | ||||
| try: | try: | ||||
| return float(value) | return float(value) | ||||
| except (TypeError, ValueError): | except (TypeError, ValueError): | ||||
| raise ValueError("{} is not a valid float".format(value)) | |||||
| raise ValueError(f"{value} is not a valid float") | |||||
| def timezone(timezone_string): | def timezone(timezone_string): | ||||
| if timezone_string and timezone_string in available_timezones(): | if timezone_string and timezone_string in available_timezones(): | ||||
| return timezone_string | return timezone_string | ||||
| error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) | |||||
| error = f"{timezone_string} is not a valid timezone." | |||||
| raise ValueError(error) | raise ValueError(error) | ||||
| key = cls._get_token_key(token, token_type) | key = cls._get_token_key(token, token_type) | ||||
| token_data_json = redis_client.get(key) | token_data_json = redis_client.get(key) | ||||
| if token_data_json is None: | if token_data_json is None: | ||||
| logging.warning(f"{token_type} token {token} not found with key {key}") | |||||
| logging.warning("%s token %s not found with key %s", token_type, token, key) | |||||
| return None | return None | ||||
| token_data: Optional[dict[str, Any]] = json.loads(token_data_json) | token_data: Optional[dict[str, Any]] = json.loads(token_data_json) | ||||
| return token_data | return token_data |
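The validators above all follow the same shape: normalize and return the value, or raise ValueError with an f-string message. A minimal sketch using the uuid case:

    import uuid

    def uuid_value(value: str) -> str:
        try:
            return str(uuid.UUID(value))
        except ValueError:
            raise ValueError(f"{value} is not a valid uuid.")

    uuid_value("123e4567-e89b-12d3-a456-426614174000")  # returns the normalized uuid string
    # uuid_value("not-a-uuid")                          # would raise ValueError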
| def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: | def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: | ||||
| filepath = os.path.join("privkeys", tenant_id, "private.pem") | filepath = os.path.join("privkeys", tenant_id, "private.pem") | ||||
| cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) | |||||
| cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}" | |||||
| private_key = redis_client.get(cache_key) | private_key = redis_client.get(cache_key) | ||||
| if not private_key: | if not private_key: | ||||
| try: | try: | ||||
| private_key = storage.load(filepath) | private_key = storage.load(filepath) | ||||
| except FileNotFoundError: | except FileNotFoundError: | ||||
| raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id)) | |||||
| raise PrivkeyNotFoundError(f"Private key not found, tenant_id: {tenant_id}") | |||||
| redis_client.setex(cache_key, 120, private_key) | redis_client.setex(cache_key, 120, private_key) | ||||
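A small sketch of how the private-key cache key is derived above: hashing the storage path yields a fixed-length, Redis-safe key regardless of the tenant id's length or characters (the tenant id below is made up):

    import hashlib
    import os

    tenant_id = "tenant-123"  # hypothetical tenant id
    filepath = os.path.join("privkeys", tenant_id, "private.pem")
    cache_key = f"tenant_privkey:{hashlib.sha3_256(filepath.encode()).hexdigest()}"
    print(cache_key)  # tenant_privkey:<64 hex chars>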
| ) | ) | ||||
| raise | raise | ||||
| except Exception as e: | except Exception as e: | ||||
| logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}") | |||||
| logging.exception("SendGridClient Unexpected error occurred while sending email to %s", _to) | |||||
| raise | raise |
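A short sketch of the log-and-reraise pattern at the end: logging.exception already records the traceback, so only the recipient is interpolated, and the bare raise preserves the original exception for callers (the failure below is simulated; SendGrid specifics are omitted):

    import logging

    def send_email(_to: str) -> None:
        try:
            raise ConnectionError("simulated transport failure")
        except Exception:
            # logging.exception appends the active traceback, so the exception
            # itself does not need to be formatted into the message.
            logging.exception("SendGridClient Unexpected error occurred while sending email to %s", _to)
            raise  # re-raise so the caller still sees the failure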