@@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):
         site = app.site
         if not site:
-            desc = args["desc"] if args["desc"] else ""
-            copy_right = args["copyright"] if args["copyright"] else ""
-            privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
-            custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
+            desc = args["desc"] or ""
+            copy_right = args["copyright"] or ""
+            privacy_policy = args["privacy_policy"] or ""
+            custom_disclaimer = args["custom_disclaimer"] or ""
         else:
-            desc = site.description if site.description else args["desc"] if args["desc"] else ""
-            copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
-            privacy_policy = (
-                site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
-            )
-            custom_disclaimer = (
-                site.custom_disclaimer
-                if site.custom_disclaimer
-                else args["custom_disclaimer"]
-                if args["custom_disclaimer"]
-                else ""
-            )
+            desc = site.description or args["desc"] or ""
+            copy_right = site.copyright or args["copyright"] or ""
+            privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
+            custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
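Note on the `x if x else y` to `x or y` rewrites in this hunk (and the many like it below): the two forms are equivalent, and the chained `a or b or ""` also flattens the nested conditional expressions in the `else` branch. A standalone sketch of the equivalence, with the one caveat worth remembering:

```python
# Minimal sketch (not from the PR): `a if a else b` and `a or b` agree,
# including for chained fallbacks.
args = {"desc": None, "copyright": "(c) 2024"}

desc = args["desc"] if args["desc"] else ""      # old style
assert desc == (args["desc"] or "")              # new style, same result

# The chained form collapses nested conditionals:
site_description = None
desc = site_description or args["desc"] or ""    # first truthy value, else ""
assert desc == ""

# Caveat: `or` falls through on *any* falsy value (0, "", [], {}), not just
# None. That is fine when empty strings should fall back anyway, but it is
# not a general replacement for `x if x is not None else y`.
assert (0 or "default") == "default"
```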
@@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
     if not account:
         # Create account
-        account_name = user_info.name if user_info.name else "Dify"
+        account_name = user_info.name or "Dify"
         account = RegisterService.register(
             email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
         )
@@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
     @login_required
     @account_initialization_required
     def get(self):
-        return {
-            "api_base_url": (
-                dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
-            )
-            + "/v1"
-        }
+        return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}
 class DatasetRetrievalSettingApi(Resource):
@@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):
         return ApiToolManageService.test_api_tool_preview(
             current_user.current_tenant_id,
-            args["provider_name"] if args["provider_name"] else "",
+            args["provider_name"] or "",
             args["tool_name"],
             args["credentials"],
             args["parameters"],
@@ -84,14 +84,10 @@ class TextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
         response = AudioService.transcript_tts(
@@ -83,14 +83,10 @@ class TextApi(WebApiResource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
-            voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+            voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
-                voice = (
-                    args.get("voice")
-                    if args.get("voice")
-                    else app_model.app_model_config.text_to_speech_dict.get("voice")
-                )
+                voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
@@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                         model=model_instance.model,
                         prompt_messages=prompt_messages,
                         message=AssistantPromptMessage(content=final_answer),
-                        usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                        usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                         system_fingerprint="",
                     )
                 ),
@@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
                         model=model_instance.model,
                         prompt_messages=prompt_messages,
                         message=AssistantPromptMessage(content=final_answer),
-                        usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                        usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                         system_fingerprint="",
                     )
                 ),
@@ -161,7 +161,7 @@ class AppRunner:
             app_mode=AppMode.value_of(app_record.mode),
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             files=files,
             context=context,
             memory=memory,
@@ -189,7 +189,7 @@ class AppRunner:
         prompt_messages = prompt_transform.get_prompt(
             prompt_template=prompt_template,
             inputs=inputs,
-            query=query if query else "",
+            query=query or "",
             files=files,
             context=context,
             memory_config=memory_config,
@@ -238,7 +238,7 @@ class AppRunner:
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=text),
-                    usage=usage if usage else LLMUsage.empty_usage(),
+                    usage=usage or LLMUsage.empty_usage(),
                 ),
             ),
             PublishFrom.APPLICATION_MANAGER,
@@ -351,7 +351,7 @@ class AppRunner:
            tenant_id=tenant_id,
            app_config=app_generate_entity.app_config,
            inputs=inputs,
-           query=query if query else "",
+           query=query or "",
            message_id=message_id,
            trace_manager=app_generate_entity.trace_manager,
        )
@@ -3,6 +3,7 @@ import importlib.util
 import json
 import logging
 import os
+from pathlib import Path
 from typing import Any, Optional
 from pydantic import BaseModel
@@ -63,8 +64,7 @@ class Extensible:
                 builtin_file_path = os.path.join(subdir_path, "__builtin__")
                 if os.path.exists(builtin_file_path):
-                    with open(builtin_file_path, encoding="utf-8") as f:
-                        position = int(f.read().strip())
+                    position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
                     position_map[extension_name] = position
                 if (extension_name + ".py") not in file_names:
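The `Path.read_text()` rewrites (here and in the extractor and Blob hunks below) fold open, read, and close into one call; like the `with` block it replaces, it closes the file before returning. A minimal standalone sketch (the file name is illustrative only):

```python
# Sketch of the Path-based read replacing an explicit context manager.
from pathlib import Path

path = Path("__builtin__")
path.write_text("3\n", encoding="utf-8")

# Old style:
with open(path, encoding="utf-8") as f:
    position = int(f.read().strip())

# New style: read_text() opens, reads, and closes in a single call.
assert position == int(path.read_text(encoding="utf-8").strip())
```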
@@ -39,7 +39,7 @@ class TokenBufferMemory:
         )
         if message_limit and message_limit > 0:
-            message_limit = message_limit if message_limit <= 500 else 500
+            message_limit = min(message_limit, 500)
         else:
             message_limit = 500
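`min()` states the clamping intent directly; a quick standalone check that it matches the conditional it replaces:

```python
# The two forms agree for every value of message_limit.
for message_limit in (1, 500, 501, 10_000):
    old = message_limit if message_limit <= 500 else 500
    assert min(message_limit, 500) == old
```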
@@ -449,7 +449,7 @@ if you are not sure about the structure.
                 model=real_model,
                 prompt_messages=prompt_messages,
                 message=prompt_message,
-                usage=usage if usage else LLMUsage.empty_usage(),
+                usage=usage or LLMUsage.empty_usage(),
                 system_fingerprint=system_fingerprint,
             ),
             credentials=credentials,
@@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                 ),
             )
         elif isinstance(chunk, ContentBlockDeltaEvent):
-            chunk_text = chunk.delta.text if chunk.delta.text else ""
+            chunk_text = chunk.delta.text or ""
             full_assistant_content += chunk_text
             # transform assistant message to prompt message
@@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
                 model=real_model,
                 prompt_messages=prompt_messages,
                 message=prompt_message,
-                usage=usage if usage else LLMUsage.empty_usage(),
+                usage=usage or LLMUsage.empty_usage(),
                 system_fingerprint=system_fingerprint,
             ),
             credentials=credentials,
@@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 continue
             # transform assistant message to prompt message
-            text = delta.text if delta.text else ""
+            text = delta.text or ""
             assistant_prompt_message = AssistantPromptMessage(content=text)
             full_text += text
@@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
                 continue
             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""
             real_model = chunk.model
             system_fingerprint = chunk.system_fingerprint
-            completion += delta.delta.content if delta.delta.content else ""
+            completion += delta.delta.content or ""
             yield LLMResultChunk(
                 model=real_model,
@@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
                 )
                 for i in range(len(sentences))
             ]
-            for index, future in enumerate(futures):
+            for future in futures:
                 yield from future.result().__enter__().iter_bytes(1024)
         else:
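Dropping `enumerate` in the TTS hunks is safe because `index` was unused, and iterating the futures list directly (rather than `as_completed`) still yields audio chunks in sentence order: results are consumed in submission order even though the calls run concurrently. A standalone sketch with a hypothetical worker:

```python
# Sketch: iterating the submitted list preserves order, while execution
# itself still overlaps across the pool. synthesize() is a stand-in.
import time
from concurrent.futures import ThreadPoolExecutor

def synthesize(sentence: str) -> str:
    time.sleep(0.01)  # stand-in for a TTS call
    return sentence.upper()

sentences = ["first", "second", "third"]
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(synthesize, s) for s in sentences]
    results = [future.result() for future in futures]  # submission order

assert results == ["FIRST", "SECOND", "THIRD"]
```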
@@ -331,10 +331,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             elif "contentBlockDelta" in chunk:
                 delta = chunk["contentBlockDelta"]["delta"]
                 if "text" in delta:
-                    chunk_text = delta["text"] if delta["text"] else ""
+                    chunk_text = delta["text"] or ""
                     full_assistant_content += chunk_text
                     assistant_prompt_message = AssistantPromptMessage(
-                        content=chunk_text if chunk_text else "",
+                        content=chunk_text or "",
                     )
                     index = chunk["contentBlockDelta"]["contentBlockIndex"]
                     yield LLMResultChunk(
@@ -751,7 +751,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         elif model_prefix == "cohere":
             output = response_body.get("generations")[0].get("text")
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
+            completion_tokens = self.get_num_tokens(model, credentials, output or "")
         else:
             raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -828,7 +828,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=content_delta if content_delta else "",
+                content=content_delta or "",
             )
             index += 1
@@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
             if delta.delta.function_call:
                 function_calls = [delta.delta.function_call]
-            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )
             if delta.finish_reason is not None:
@@ -511,7 +511,7 @@ class LocalAILanguageModel(LargeLanguageModel):
             delta = chunk.choices[0]
             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+            assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
             if delta.finish_reason is not None:
                 # temp_assistant_prompt_message is used to calculate usage
@@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
             if delta.delta.function_call:
                 function_calls = [delta.delta.function_call]
-            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+            assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )
             if delta.finish_reason is not None:
@@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                     usage=usage,
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
         elif message.function_call:
@@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                 delta=LLMResultChunkDelta(
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
@@ -65,7 +65,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
         inputs = []
         used_tokens = 0
-        for i, text in enumerate(texts):
+        for text in texts:
             # Here token count is only an approximation based on the GPT2 tokenizer
             num_tokens = self._get_num_tokens_by_gpt2(text)
@@ -508,7 +508,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
             continue
             # transform assistant message to prompt message
-            text = delta.text if delta.text else ""
+            text = delta.text or ""
             assistant_prompt_message = AssistantPromptMessage(content=text)
             full_text += text
@@ -760,11 +760,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
                 final_tool_calls.extend(tool_calls)
             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""
             if has_finish_reason:
                 final_chunk = LLMResultChunk(
@@ -88,7 +88,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
                 )
                 for i in range(len(sentences))
             ]
-            for index, future in enumerate(futures):
+            for future in futures:
                 yield from future.result().__enter__().iter_bytes(1024)
         else:
@@ -179,9 +179,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
         features = []
         function_calling_type = credentials.get("function_calling_type", "no_call")
-        if function_calling_type in ["function_call"]:
+        if function_calling_type == "function_call":
             features.append(ModelFeature.TOOL_CALL)
-        elif function_calling_type in ["tool_call"]:
+        elif function_calling_type == "tool_call":
             features.append(ModelFeature.MULTI_TOOL_CALL)
         stream_function_calling = credentials.get("stream_function_calling", "supported")
@@ -179,7 +179,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                     usage=usage,
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
         else:
@@ -189,7 +189,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                 delta=LLMResultChunkDelta(
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
@@ -106,7 +106,7 @@ class OpenLLMGenerate:
             timeout = 120
         data = {
-            "stop": stop if stop else [],
+            "stop": stop or [],
             "prompt": "\n".join([message.content for message in prompt_messages]),
             "llm_config": default_llm_config,
         }
@@ -214,7 +214,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
             index += 1
-            assistant_prompt_message = AssistantPromptMessage(content=output if output else "")
+            assistant_prompt_message = AssistantPromptMessage(content=output or "")
             if index < prediction_output_length:
                 yield LLMResultChunk(
@@ -1,5 +1,6 @@
 import json
 import logging
+import operator
 from typing import Any, Optional
 import boto3
@@ -94,7 +95,7 @@ class SageMakerRerankModel(RerankModel):
             for idx in range(len(scores)):
                 candidate_docs.append({"content": docs[idx], "score": scores[idx]})
-            sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
+            sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
             line = 3
             rerank_documents = []
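`operator.itemgetter("score")` is the idiomatic replacement for the key lambda (same change as in the Tencent signature and SageMaker tool hunks below). One observation on this hunk, unchanged by the PR: `sorted()` returns a new list, so a bare `sorted(candidate_docs, ...)` statement discards its result; if the ranking is meant to take effect, `candidate_docs.sort(...)` would be the in-place fix. Sketch:

```python
# itemgetter as a sort key, plus the in-place variant. Note that a bare
# sorted(...) statement (as in the hunk above) throws its result away.
import operator

candidate_docs = [{"content": "a", "score": 0.2}, {"content": "b", "score": 0.9}]

ranked = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
assert ranked[0]["content"] == "b"

# In-place equivalent, if the original list itself should end up sorted:
candidate_docs.sort(key=operator.itemgetter("score"), reverse=True)
assert candidate_docs == ranked
```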
@@ -260,7 +260,7 @@ class SageMakerText2SpeechModel(TTSModel):
                 for payload in payloads
             ]
-            for index, future in enumerate(futures):
+            for future in futures:
                 resp = future.result()
                 audio_bytes = requests.get(resp.get("s3_presign_url")).content
                 for i in range(0, len(audio_bytes), 1024):
@@ -220,7 +220,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
                 delta = content
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta if delta else "",
+                content=delta or "",
             )
             prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
@@ -1,6 +1,7 @@
 import base64
 import hashlib
 import hmac
+import operator
 import time
 import requests
@@ -127,7 +128,7 @@ class FlashRecognizer:
         return s
     def _build_req_with_signature(self, secret_key, params, header):
-        query = sorted(params.items(), key=lambda d: d[0])
+        query = sorted(params.items(), key=operator.itemgetter(0))
         signstr = self._format_sign_string(query)
         signature = self._sign(signstr, secret_key)
         header["Authorization"] = signature
@@ -4,6 +4,7 @@ import tempfile
 import uuid
 from collections.abc import Generator
 from http import HTTPStatus
+from pathlib import Path
 from typing import Optional, Union, cast
 from dashscope import Generation, MultiModalConversation, get_tokenizer
@@ -454,8 +455,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
             file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")
-            with open(file_path, "wb") as image_file:
-                image_file.write(base64.b64decode(encoded_string))
+            Path(file_path).write_bytes(base64.b64decode(encoded_string))
             return f"file://{file_path}"
@@ -368,11 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
                 final_tool_calls.extend(tool_calls)
             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
-            )
+            assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""
             if has_finish_reason:
                 final_chunk = LLMResultChunk(
@@ -231,10 +231,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
                 ),
             )
         elif isinstance(chunk, ContentBlockDeltaEvent):
-            chunk_text = chunk.delta.text if chunk.delta.text else ""
+            chunk_text = chunk.delta.text or ""
             full_assistant_content += chunk_text
             assistant_prompt_message = AssistantPromptMessage(
-                content=chunk_text if chunk_text else "",
+                content=chunk_text or "",
             )
             index = chunk.index
             yield LLMResultChunk(
@@ -1,5 +1,6 @@
 # coding : utf-8
 import datetime
+from itertools import starmap
 import pytz
@@ -48,7 +49,7 @@ class SignResult:
         self.authorization = ""
     def __str__(self):
-        return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()])
+        return "\n".join(list(starmap("{}:{}".format, self.__dict__.items())))
 class Credentials:
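`starmap` unpacks each `(key, value)` tuple into the callable, replacing the `"{}:{}".format(*item)` comprehension one-for-one (the zhipuai `_make_multipartform` hunk below is the same pattern). Since `str.join` accepts any iterable, the `list(...)` wrapper is optional. Sketch:

```python
# starmap(f, pairs) is equivalent to (f(*pair) for pair in pairs).
from itertools import starmap

attrs = {"host": "example.com", "region": "cn-north-1"}

old = "\n".join(["{}:{}".format(*item) for item in attrs.items()])
new = "\n".join(starmap("{}:{}".format, attrs.items()))
assert old == new  # str.join takes any iterable, so list(...) is optional
```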
@@ -1,5 +1,6 @@
 import hashlib
 import hmac
+import operator
 from functools import reduce
 from urllib.parse import quote
@@ -40,4 +41,4 @@ class Util:
             if len(hv) == 1:
                 hv = "0" + hv
             lst.append(hv)
-        return reduce(lambda x, y: x + y, lst)
+        return reduce(operator.add, lst)
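`operator.add` is a drop-in for `lambda x, y: x + y`. For concatenating string fragments specifically, `"".join(lst)` would do the same job in linear time and may be a worthwhile follow-up, since `reduce` with `+` re-copies the accumulator on every step. Sketch:

```python
# Both forms produce the same string; join() avoids quadratic copying.
import operator
from functools import reduce

lst = ["de", "ad", "be", "ef"]
assert reduce(operator.add, lst) == "".join(lst) == "deadbeef"
```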
@@ -174,9 +174,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
                     prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=index,
-                        message=AssistantPromptMessage(
-                            content=message["content"] if message["content"] else "", tool_calls=[]
-                        ),
+                        message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]),
                         usage=usage,
                         finish_reason=choice.get("finish_reason"),
                     ),
@@ -208,7 +206,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
-                content=message["content"] if message["content"] else "",
+                content=message["content"] or "",
                 tool_calls=tool_calls,
             ),
             usage=self._calc_response_usage(
@@ -284,7 +282,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
             model=model,
             prompt_messages=prompt_messages,
             message=AssistantPromptMessage(
-                content=message.content if message.content else "",
+                content=message.content or "",
                 tool_calls=tool_calls,
             ),
             usage=self._calc_response_usage(
@@ -199,7 +199,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
             secret_key=credentials["secret_key"],
         )
-        user = user if user else "ErnieBotDefault"
+        user = user or "ErnieBotDefault"
         # convert prompt messages to baichuan messages
         messages = [
@@ -289,7 +289,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
                     usage=usage,
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
         else:
@@ -299,7 +299,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
                 delta=LLMResultChunkDelta(
                     index=0,
                     message=AssistantPromptMessage(content=message.content, tool_calls=[]),
-                    finish_reason=message.stop_reason if message.stop_reason else None,
+                    finish_reason=message.stop_reason or None,
                 ),
             )
@@ -85,7 +85,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
         api_key = credentials["api_key"]
         secret_key = credentials["secret_key"]
         embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
-        user = user if user else "ErnieBotDefault"
+        user = user or "ErnieBotDefault"
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -589,7 +589,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             # convert tool call to assistant message tool call
             tool_calls = assistant_message.tool_calls
-            assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
+            assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or [])
             function_call = assistant_message.function_call
             if function_call:
                 assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
@@ -652,7 +652,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
             )
             if delta.finish_reason is not None:
@@ -749,7 +749,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
             delta = chunk.choices[0]
             # transform assistant message to prompt message
-            assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+            assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])
             if delta.finish_reason is not None:
                 # temp_assistant_prompt_message is used to calculate usage
@@ -215,7 +215,7 @@ class XinferenceText2SpeechModel(TTSModel):
                 for i in range(len(sentences))
             ]
-            for index, future in enumerate(futures):
+            for future in futures:
                 response = future.result()
                 for i in range(0, len(response), 1024):
                     yield response[i : i + 1024]
@@ -414,10 +414,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             # transform assistant message to prompt message
             assistant_prompt_message = AssistantPromptMessage(
-                content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls
+                content=delta.delta.content or "", tool_calls=assistant_tool_calls
            )
-            full_assistant_content += delta.delta.content if delta.delta.content else ""
+            full_assistant_content += delta.delta.content or ""
             if delta.finish_reason is not None and chunk.usage is not None:
                 completion_tokens = chunk.usage.completion_tokens
@@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
     return {key: val for key, val in merged.items() if val is not None}
+from itertools import starmap
 from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
 ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
@@ -159,7 +161,7 @@ class HttpClient:
         return [(key, str_data)]
     def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
-        items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
+        items = flatten(list(starmap(self._object_to_formdata, data.items())))
         serialized: dict[str, object] = {}
         for key, value in items:
@@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             self.generate_name_trace(trace_info)
     def workflow_trace(self, trace_info: WorkflowTraceInfo):
-        trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+        trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
         user_id = trace_info.metadata.get("user_id")
         if trace_info.message_id:
             trace_id = trace_info.message_id
@@ -84,7 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             )
             self.add_trace(langfuse_trace_data=trace_data)
             workflow_span_data = LangfuseSpan(
-                id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
+                id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
                 name=TraceTaskName.WORKFLOW_TRACE.value,
                 input=trace_info.workflow_run_inputs,
                 output=trace_info.workflow_run_outputs,
@@ -93,7 +93,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 end_time=trace_info.end_time,
                 metadata=trace_info.metadata,
                 level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
-                status_message=trace_info.error if trace_info.error else "",
+                status_message=trace_info.error or "",
             )
             self.add_span(langfuse_span_data=workflow_span_data)
         else:
@@ -143,7 +143,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
             outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
-            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+            created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
@@ -172,10 +172,8 @@ class LangFuseDataTrace(BaseTraceInstance):
                     end_time=finished_at,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
-                    parent_observation_id=(
-                        trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
-                    ),
+                    status_message=trace_info.error or "",
+                    parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
                 )
             else:
                 span_data = LangfuseSpan(
@@ -188,7 +186,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                     end_time=finished_at,
                     metadata=metadata,
                     level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-                    status_message=trace_info.error if trace_info.error else "",
+                    status_message=trace_info.error or "",
                 )
             self.add_span(langfuse_span_data=span_data)
@@ -212,7 +210,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=outputs,
             metadata=metadata,
             level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
-            status_message=trace_info.error if trace_info.error else "",
+            status_message=trace_info.error or "",
             usage=generation_usage,
         )
@@ -277,7 +275,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=message_data.answer,
             metadata=metadata,
             level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
-            status_message=message_data.error if message_data.error else "",
+            status_message=message_data.error or "",
             usage=generation_usage,
         )
@@ -319,7 +317,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
             level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
-            status_message=message_data.error if message_data.error else "",
+            status_message=message_data.error or "",
             usage=generation_usage,
         )
@@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance):
         langsmith_run = LangSmithRunModel(
             file_list=trace_info.file_list,
             total_tokens=trace_info.total_tokens,
-            id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+            id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
             name=TraceTaskName.WORKFLOW_TRACE.value,
             inputs=trace_info.workflow_run_inputs,
             run_type=LangSmithRunType.tool,
@@ -94,7 +94,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             },
             error=trace_info.error,
             tags=["workflow"],
-            parent_run_id=trace_info.message_id if trace_info.message_id else None,
+            parent_run_id=trace_info.message_id or None,
         )
         self.add_run(langsmith_run)
@@ -133,7 +133,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             else:
                 inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
             outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
-            created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+            created_at = node_execution.created_at or datetime.now()
             elapsed_time = node_execution.elapsed_time
             finished_at = created_at + timedelta(seconds=elapsed_time)
@@ -180,9 +180,7 @@ class LangSmithDataTrace(BaseTraceInstance):
             extra={
                 "metadata": metadata,
             },
-            parent_run_id=trace_info.workflow_app_log_id
-            if trace_info.workflow_app_log_id
-            else trace_info.workflow_run_id,
+            parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
             tags=["node_execution"],
         )
@@ -354,11 +354,11 @@ class TraceTask:
         workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
         workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
         workflow_run_version = workflow_run.version
-        error = workflow_run.error if workflow_run.error else ""
+        error = workflow_run.error or ""
         total_tokens = workflow_run.total_tokens
-        file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else []
+        file_list = workflow_run_inputs.get("sys.file") or []
         query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
         # get workflow_app_log_id
@@ -452,7 +452,7 @@ class TraceTask:
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
             total_tokens=message_tokens + message_data.answer_tokens,
-            error=message_data.error if message_data.error else "",
+            error=message_data.error or "",
             inputs=inputs,
             outputs=message_data.answer,
             file_list=file_list,
@@ -487,7 +487,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
         moderation_trace_info = ModerationTraceInfo(
-            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+            message_id=workflow_app_log_id or message_id,
             inputs=inputs,
             message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
@@ -527,7 +527,7 @@ class TraceTask:
         workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
         suggested_question_trace_info = SuggestedQuestionTraceInfo(
-            message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+            message_id=workflow_app_log_id or message_id,
             message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
@@ -569,7 +569,7 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
-            inputs=message_data.query if message_data.query else message_data.inputs,
+            inputs=message_data.query or message_data.inputs,
             documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
@@ -695,8 +695,7 @@ class TraceQueueManager:
         self.start_timer()
     def add_trace_task(self, trace_task: TraceTask):
-        global trace_manager_timer
-        global trace_manager_queue
+        global trace_manager_timer, trace_manager_queue
         try:
             if self.trace_instance:
                 trace_task.app_id = self.app_id
@@ -112,11 +112,11 @@ class SimplePromptTransform(PromptTransform):
         for v in prompt_template_config["special_variable_keys"]:
             # support #context#, #query# and #histories#
             if v == "#context#":
-                variables["#context#"] = context if context else ""
+                variables["#context#"] = context or ""
             elif v == "#query#":
-                variables["#query#"] = query if query else ""
+                variables["#query#"] = query or ""
             elif v == "#histories#":
-                variables["#histories#"] = histories if histories else ""
+                variables["#histories#"] = histories or ""
         prompt_template = prompt_template_config["prompt_template"]
         prompt = prompt_template.format(variables)
@@ -34,7 +34,7 @@ class BaseKeyword(ABC):
         raise NotImplementedError
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
            doc_id = text.metadata["doc_id"]
            exists_duplicate_node = self.text_exists(doc_id)
            if exists_duplicate_node:
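`texts.copy()` and `texts[:]` both take a shallow copy, which matters here because the loop presumably removes duplicates from `texts` while iterating; removing from the list being iterated would silently skip elements. A standalone sketch of the pitfall the copy avoids:

```python
# Removing from a list while iterating over that same list skips the
# element that slides into the removed slot; iterating a copy is safe.
texts = ["dup", "keep", "dup", "keep"]

for text in texts.copy():   # texts[:] behaves identically
    if text == "dup":
        texts.remove(text)

assert texts == ["keep", "keep"]
```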
@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
@@ -267,7 +267,7 @@ class AnalyticdbVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         request = gpdb_20160503_models.QueryCollectionDataRequest(
             dbinstance_id=self.config.instance_id,
             region_id=self.config.region_id,
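One behavioral nuance in these `score_threshold` hunks (and the similar ones across the vector stores below): `kwargs.get(key, 0.0)` only falls back when the key is absent, whereas the old `x if x else 0.0` also coerced an explicitly passed `None` (or `0`) to `0.0`. If any caller passes `score_threshold=None`, a later `score > score_threshold` comparison would now raise a `TypeError`. A standalone check of the difference:

```python
# get(key, default) falls back only on a *missing* key; the old form also
# coerced explicit falsy values. Behavior change if None is passed in.
kwargs = {"score_threshold": None}

old = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
new = kwargs.get("score_threshold", 0.0)

assert old == 0.0
assert new is None  # a later `score > new` would raise TypeError
```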
@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         collection = self._client.get_or_create_collection(self._collection_name)
         results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         ids: list[str] = results["ids"][0]
         documents: list[str] = results["documents"][0]
@@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
                 id=uuids[i],
                 document={
                     Field.CONTENT_KEY.value: documents[i].page_content,
-                    Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
-                    Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
+                    Field.VECTOR.value: embeddings[i] or None,
+                    Field.METADATA_KEY.value: documents[i].metadata or {},
                 },
             )
         self._client.indices.refresh(index=self._collection_name)
@@ -131,7 +131,7 @@ class ElasticSearchVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc.metadata["score"] = score
                 docs.append(doc)
@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
         for result in results[0]:
             metadata = result["entity"].get(Field.METADATA_KEY.value)
             metadata["score"] = result["distance"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result["distance"] > score_threshold:
                 doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)
@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):
     def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") or 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         where_str = (
             f"WHERE dist < {1 - score_threshold}"
             if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0
@@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
                 metadata = {}
             metadata["score"] = hit["_score"]
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if hit["_score"] > score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)
@@ -200,7 +200,7 @@ class OracleVector(BaseVector):
             [numpy.array(query_vector)],
         )
         docs = []
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         for record in cur:
             metadata, text, distance = record
             score = 1 - distance
@@ -212,7 +212,7 @@ class OracleVector(BaseVector):
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
         # just not implement fetch by score_threshold now, may be later
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         if len(query) > 0:
             # Check which language the query is in
             zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
             metadata = record.meta
             score = 1 - dis
             metadata["score"] = score
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if score > score_threshold:
                 doc = Document(page_content=record.text, metadata=metadata)
                 docs.append(doc)
@@ -144,7 +144,7 @@ class PGVector(BaseVector):
             (json.dumps(query_vector),),
         )
         docs = []
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         for record in cur:
             metadata, text, distance = record
             score = 1 - distance
@@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
         for result in results:
             metadata = result.payload.get(Field.METADATA_KEY.value) or {}
             # duplicate check score threshold
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if result.score > score_threshold:
                 metadata["score"] = result.score
                 doc = Document(
@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
         # Organize results.
         docs = []
         for document, score in results:
-            score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             if 1 - score > score_threshold:
                 docs.append(document)
         return docs
@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
             limit=kwargs.get("top_k", 4),
             timeout=self._client_config.timeout,
         )
-        score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         return self._get_search_res(res, score_threshold)
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):
     def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
         top_k = kwargs.get("top_k", 5)
-        score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+        score_threshold = kwargs.get("score_threshold", 0.0)
         filter = kwargs.get("filter")
         distance = 1 - score_threshold
@@ -49,7 +49,7 @@ class BaseVector(ABC):
         raise NotImplementedError
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
@@ -153,7 +153,7 @@ class Vector:
         return CacheEmbedding(embedding_model)
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
-        for text in texts[:]:
+        for text in texts.copy():
             doc_id = text.metadata["doc_id"]
             exists_duplicate_node = self.text_exists(doc_id)
             if exists_duplicate_node:
@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
-            score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+            score_threshold = kwargs.get("score_threshold", 0.0)
             # check score threshold
             if score > score_threshold:
                 doc.metadata["score"] = score
@@ -12,7 +12,7 @@ import mimetypes
 from abc import ABC, abstractmethod
 from collections.abc import Generator, Iterable, Mapping
 from io import BufferedReader, BytesIO
-from pathlib import PurePath
+from pathlib import Path, PurePath
 from typing import Any, Optional, Union
 from pydantic import BaseModel, ConfigDict, model_validator
@@ -56,8 +56,7 @@ class Blob(BaseModel):
     def as_string(self) -> str:
         """Read data as a string."""
         if self.data is None and self.path:
-            with open(str(self.path), encoding=self.encoding) as f:
-                return f.read()
+            return Path(str(self.path)).read_text(encoding=self.encoding)
         elif isinstance(self.data, bytes):
             return self.data.decode(self.encoding)
         elif isinstance(self.data, str):
@@ -72,8 +71,7 @@ class Blob(BaseModel):
         elif isinstance(self.data, str):
             return self.data.encode(self.encoding)
         elif self.data is None and self.path:
-            with open(str(self.path), "rb") as f:
-                return f.read()
+            return Path(str(self.path)).read_bytes()
         else:
             raise ValueError(f"Unable to get bytes for blob {self}")
@@ -68,8 +68,7 @@ class ExtractProcessor:
                 suffix = "." + re.search(r"\.(\w+)$", filename).group(1)
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
-            with open(file_path, "wb") as file:
-                file.write(response.content)
+            Path(file_path).write_bytes(response.content)
             extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
             if return_text:
                 delimiter = "\n"
@@ -111,7 +110,7 @@ class ExtractProcessor:
                     )
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
@@ -143,7 +142,7 @@ class ExtractProcessor:
                     extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
                 elif file_extension in [".htm", ".html"]:
                     extractor = HtmlExtractor(file_path)
-                elif file_extension in [".docx"]:
+                elif file_extension == ".docx":
                     extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
                 elif file_extension == ".csv":
                     extractor = CSVExtractor(file_path, autodetect_encoding=True)
@@ -1,6 +1,7 @@
 """Document loader helpers."""
 import concurrent.futures
+from pathlib import Path
 from typing import NamedTuple, Optional, cast
@@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
     import chardet
     def read_and_detect(file_path: str) -> list[dict]:
-        with open(file_path, "rb") as f:
-            rawdata = f.read()
+        rawdata = Path(file_path).read_bytes()
         return cast(list[dict], chardet.detect_all(rawdata))
     with concurrent.futures.ThreadPoolExecutor() as executor:
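For context, a standalone sketch of the pattern around this hunk: the read now uses `Path.read_bytes()`, the detection runs in a worker thread, and `future.result(timeout=...)` bounds how long `chardet.detect_all` may take. This assumes `chardet` is installed; the real helper's timeout handling may differ:

```python
# Sketch: bounded encoding detection in a worker thread.
import concurrent.futures
from pathlib import Path

import chardet

def read_and_detect(file_path: str) -> list[dict]:
    rawdata = Path(file_path).read_bytes()   # replaces open()/read()
    return chardet.detect_all(rawdata)

with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(read_and_detect, __file__)
    encodings = future.result(timeout=5)     # raises TimeoutError if exceeded

print(encodings[0]["encoding"])
```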
| @@ -1,6 +1,7 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| import re | |||
| from pathlib import Path | |||
| from typing import Optional, cast | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| @@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor): | |||
| """Parse file into tuples.""" | |||
| content = "" | |||
| try: | |||
| with open(filepath, encoding=self._encoding) as f: | |||
| content = f.read() | |||
| content = Path(filepath).read_text(encoding=self._encoding) | |||
| except UnicodeDecodeError as e: | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(filepath) | |||
| for encoding in detected_encodings: | |||
| try: | |||
| with open(filepath, encoding=encoding.encoding) as f: | |||
| content = f.read() | |||
| content = Path(filepath).read_text(encoding=encoding.encoding) | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| @@ -1,5 +1,6 @@ | |||
| """Abstract interface for document loader implementations.""" | |||
| from pathlib import Path | |||
| from typing import Optional | |||
| from core.rag.extractor.extractor_base import BaseExtractor | |||
| @@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor): | |||
| """Load from file path.""" | |||
| text = "" | |||
| try: | |||
| with open(self._file_path, encoding=self._encoding) as f: | |||
| text = f.read() | |||
| text = Path(self._file_path).read_text(encoding=self._encoding) | |||
| except UnicodeDecodeError as e: | |||
| if self._autodetect_encoding: | |||
| detected_encodings = detect_file_encodings(self._file_path) | |||
| for encoding in detected_encodings: | |||
| try: | |||
| with open(self._file_path, encoding=encoding.encoding) as f: | |||
| text = f.read() | |||
| text = Path(self._file_path).read_text(encoding=encoding.encoding) | |||
| break | |||
| except UnicodeDecodeError: | |||
| continue | |||
| @@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor): | |||
| if col_index >= total_cols: | |||
| break | |||
| cell_content = self._parse_cell(cell, image_map).strip() | |||
| cell_colspan = cell.grid_span if cell.grid_span else 1 | |||
| cell_colspan = cell.grid_span or 1 | |||
| for i in range(cell_colspan): | |||
| if col_index + i < total_cols: | |||
| row_cells[col_index + i] = cell_content if i == 0 else "" | |||
| @@ -256,7 +256,7 @@ class DatasetRetrieval: | |||
| # get retrieval model config | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if dataset: | |||
| retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| retrieval_model_config = dataset.retrieval_model or default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config["top_k"] | |||
| @@ -410,7 +410,7 @@ class DatasetRetrieval: | |||
| return [] | |||
| # get retrieval model , if the model is not setting , using default | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| retrieval_model = dataset.retrieval_model or default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| @@ -433,9 +433,7 @@ class DatasetRetrieval: | |||
| reranking_model=retrieval_model.get("reranking_model", None) | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") | |||
| if retrieval_model.get("reranking_mode") | |||
| else "reranking_model", | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| ) | |||
| @@ -486,7 +484,7 @@ class DatasetRetrieval: | |||
| } | |||
| for dataset in available_datasets: | |||
| retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| retrieval_model_config = dataset.retrieval_model or default_retrieval_model | |||
| # get top k | |||
| top_k = retrieval_model_config["top_k"] | |||
| @@ -106,7 +106,7 @@ class ApiToolProviderController(ToolProviderController): | |||
| "human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""}, | |||
| "llm": tool_bundle.summary or "", | |||
| }, | |||
| "parameters": tool_bundle.parameters if tool_bundle.parameters else [], | |||
| "parameters": tool_bundle.parameters or [], | |||
| } | |||
| ) | |||
| @@ -1,4 +1,5 @@ | |||
| import json | |||
| import operator | |||
| from typing import Any, Union | |||
| import boto3 | |||
| @@ -71,7 +72,7 @@ class SageMakerReRankTool(BuiltinTool): | |||
| candidate_docs[idx]["score"] = scores[idx] | |||
| line = 8 | |||
| sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True) | |||
| sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True) | |||
| line = 9 | |||
| return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]] | |||
| @@ -115,7 +115,7 @@ class GetWorksheetFieldsTool(BuiltinTool): | |||
| fields.append(field) | |||
| fields_list.append( | |||
| f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}" | |||
| f"|{field['options'] if field['options'] else ''}|" | |||
| f"|{field['options'] or ''}|" | |||
| ) | |||
| fields.append( | |||
| @@ -130,7 +130,7 @@ class GetWorksheetPivotDataTool(BuiltinTool): | |||
| # ] | |||
| rows = [] | |||
| for row in data["data"]: | |||
| row_data = row["rows"] if row["rows"] else {} | |||
| row_data = row["rows"] or {} | |||
| row_data.update(row["columns"]) | |||
| row_data.update(row["values"]) | |||
| rows.append(row_data) | |||
| @@ -113,7 +113,7 @@ class ListWorksheetRecordsTool(BuiltinTool): | |||
| result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." | |||
| if result["total"] > 0: | |||
| result_text += ( | |||
| f" The following are {result['total'] if result['total'] < limit else limit}" | |||
| f" The following are {min(limit, result['total'])}" | |||
| f" pieces of data presented in a table format:\n\n{table_header}" | |||
| ) | |||
| for row in rows: | |||
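Note: the f-string rewrite above folds a comparison ternary into `min()`; the two agree at every point, including the tie. A quick check:

```python
for total, limit in [(5, 10), (10, 10), (250, 100)]:
    assert (total if total < limit else limit) == min(limit, total)
```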
| @@ -37,7 +37,7 @@ class SearchAPI: | |||
| return { | |||
| "engine": "youtube_transcripts", | |||
| "video_id": video_id, | |||
| "lang": language if language else "en", | |||
| "lang": language or "en", | |||
| **{key: value for key, value in kwargs.items() if value not in [None, ""]}, | |||
| } | |||
| @@ -160,7 +160,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| hit_callback.on_query(query, dataset.id) | |||
| # get the retrieval model; if none is set, use the default | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| retrieval_model = dataset.retrieval_model or default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| @@ -183,9 +183,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): | |||
| reranking_model=retrieval_model.get("reranking_model", None) | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") | |||
| if retrieval_model.get("reranking_mode") | |||
| else "reranking_model", | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| ) | |||
| @@ -55,7 +55,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| hit_callback.on_query(query, dataset.id) | |||
| # get the retrieval model; if none is set, use the default | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| retrieval_model = dataset.retrieval_model or default_retrieval_model | |||
| if dataset.indexing_technique == "economy": | |||
| # use keyword table query | |||
| documents = RetrievalService.retrieve( | |||
| @@ -76,9 +76,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): | |||
| reranking_model=retrieval_model.get("reranking_model", None) | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") | |||
| if retrieval_model.get("reranking_mode") | |||
| else "reranking_model", | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| ) | |||
| else: | |||
| @@ -8,6 +8,7 @@ import subprocess | |||
| import tempfile | |||
| import unicodedata | |||
| from contextlib import contextmanager | |||
| from pathlib import Path | |||
| from urllib.parse import unquote | |||
| import chardet | |||
| @@ -98,7 +99,7 @@ def get_url(url: str, user_agent: str = None) -> str: | |||
| authors=a["byline"], | |||
| publish_date=a["date"], | |||
| top_image="", | |||
| text=a["plain_text"] if a["plain_text"] else "", | |||
| text=a["plain_text"] or "", | |||
| ) | |||
| return res | |||
| @@ -117,8 +118,7 @@ def extract_using_readabilipy(html): | |||
| subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) | |||
| # Read output of call to Readability.parse() from JSON file and return as Python dictionary | |||
| with open(article_json_path, encoding="utf-8") as json_file: | |||
| input_json = json.loads(json_file.read()) | |||
| input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) | |||
| # Deleting files after processing | |||
| os.unlink(article_json_path) | |||
| @@ -21,7 +21,7 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any | |||
| with open(file_path, encoding="utf-8") as yaml_file: | |||
| try: | |||
| yaml_content = yaml.safe_load(yaml_file) | |||
| return yaml_content if yaml_content else default_value | |||
| return yaml_content or default_value | |||
| except Exception as e: | |||
| raise YAMLError(f"Failed to load YAML file {file_path}: {e}") | |||
| except Exception as e: | |||
| @@ -268,7 +268,7 @@ class Graph(BaseModel): | |||
| f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph." | |||
| ) | |||
| new_route = route[:] | |||
| new_route = route.copy() | |||
| new_route.append(graph_edge.target_node_id) | |||
| cls._check_connected_to_previous_node( | |||
| route=new_route, | |||
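Note: `route[:]` and `route.copy()` produce the same shallow copy; refurb simply prefers the explicit method. Sketch:

```python
route = ["start", "branch-a"]
copied = route.copy()                    # same shallow copy as route[:]
copied.append("branch-b")
assert route == ["start", "branch-a"]    # original route is untouched
assert route[:] == route.copy()
```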
| @@ -679,8 +679,7 @@ class Graph(BaseModel): | |||
| all_routes_node_ids = set() | |||
| parallel_start_node_ids: dict[str, list[str]] = {} | |||
| for branch_node_id, node_ids in routes_node_ids.items(): | |||
| for node_id in node_ids: | |||
| all_routes_node_ids.add(node_id) | |||
| all_routes_node_ids.update(node_ids) | |||
| if branch_node_id in reverse_edge_mapping: | |||
| for graph_edge in reverse_edge_mapping[branch_node_id]: | |||
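Note: `set.update(iterable)` replaces the element-by-element `add` loop with a single call; behavior is identical, duplicates included. Sketch:

```python
node_ids = ["a", "b", "a"]

looped: set[str] = set()
for node_id in node_ids:       # the loop form being removed
    looped.add(node_id)

batched: set[str] = set()
batched.update(node_ids)       # one call, accepts any iterable
assert looped == batched == {"a", "b"}
```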
| @@ -74,7 +74,7 @@ class CodeNode(BaseNode): | |||
| :return: | |||
| """ | |||
| if not isinstance(value, str): | |||
| if isinstance(value, type(None)): | |||
| if value is None: | |||
| return None | |||
| else: | |||
| raise ValueError(f"Output variable `{variable}` must be a string") | |||
| @@ -95,7 +95,7 @@ class CodeNode(BaseNode): | |||
| :return: | |||
| """ | |||
| if not isinstance(value, int | float): | |||
| if isinstance(value, type(None)): | |||
| if value is None: | |||
| return None | |||
| else: | |||
| raise ValueError(f"Output variable `{variable}` must be a number") | |||
| @@ -182,7 +182,7 @@ class CodeNode(BaseNode): | |||
| f"Output {prefix}.{output_name} is not a valid array." | |||
| f" make sure all elements are of the same type." | |||
| ) | |||
| elif isinstance(output_value, type(None)): | |||
| elif output_value is None: | |||
| pass | |||
| else: | |||
| raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") | |||
| @@ -284,7 +284,7 @@ class CodeNode(BaseNode): | |||
| for i, value in enumerate(result[output_name]): | |||
| if not isinstance(value, dict): | |||
| if isinstance(value, type(None)): | |||
| if value is None: | |||
| pass | |||
| else: | |||
| raise ValueError( | |||
| @@ -79,7 +79,7 @@ class IfElseNode(BaseNode): | |||
| status=WorkflowNodeExecutionStatus.SUCCEEDED, | |||
| inputs=node_inputs, | |||
| process_data=process_datas, | |||
| edge_source_handle=selected_case_id if selected_case_id else "false", # use the case ID, or "false" if no case matched | |||
| edge_source_handle=selected_case_id or "false", # use the case ID, or "false" if no case matched | |||
| outputs=outputs, | |||
| ) | |||
| @@ -580,7 +580,7 @@ class LLMNode(BaseNode): | |||
| prompt_messages = prompt_transform.get_prompt( | |||
| prompt_template=node_data.prompt_template, | |||
| inputs=inputs, | |||
| query=query if query else "", | |||
| query=query or "", | |||
| files=files, | |||
| context=context, | |||
| memory_config=node_data.memory, | |||
| @@ -250,7 +250,7 @@ class QuestionClassifierNode(LLMNode): | |||
| for class_ in classes: | |||
| category = {"category_id": class_.id, "category_name": class_.name} | |||
| categories.append(category) | |||
| instruction = node_data.instruction if node_data.instruction else "" | |||
| instruction = node_data.instruction or "" | |||
| input_text = query | |||
| memory_str = "" | |||
| if memory: | |||
| @@ -18,8 +18,7 @@ def handle(sender, **kwargs): | |||
| added_dataset_ids = dataset_ids | |||
| else: | |||
| old_dataset_ids = set() | |||
| for app_dataset_join in app_dataset_joins: | |||
| old_dataset_ids.add(app_dataset_join.dataset_id) | |||
| old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) | |||
| added_dataset_ids = dataset_ids - old_dataset_ids | |||
| removed_dataset_ids = old_dataset_ids - dataset_ids | |||
| @@ -22,8 +22,7 @@ def handle(sender, **kwargs): | |||
| added_dataset_ids = dataset_ids | |||
| else: | |||
| old_dataset_ids = set() | |||
| for app_dataset_join in app_dataset_joins: | |||
| old_dataset_ids.add(app_dataset_join.dataset_id) | |||
| old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins) | |||
| added_dataset_ids = dataset_ids - old_dataset_ids | |||
| removed_dataset_ids = old_dataset_ids - dataset_ids | |||
| @@ -1,6 +1,7 @@ | |||
| import os | |||
| import shutil | |||
| from collections.abc import Generator | |||
| from pathlib import Path | |||
| from flask import Flask | |||
| @@ -26,8 +27,7 @@ class LocalStorage(BaseStorage): | |||
| folder = os.path.dirname(filename) | |||
| os.makedirs(folder, exist_ok=True) | |||
| with open(os.path.join(os.getcwd(), filename), "wb") as f: | |||
| f.write(data) | |||
| Path(os.path.join(os.getcwd(), filename)).write_bytes(data) | |||
| def load_once(self, filename: str) -> bytes: | |||
| if not self.folder or self.folder.endswith("/"): | |||
| @@ -38,9 +38,7 @@ class LocalStorage(BaseStorage): | |||
| if not os.path.exists(filename): | |||
| raise FileNotFoundError("File not found") | |||
| with open(filename, "rb") as f: | |||
| data = f.read() | |||
| data = Path(filename).read_bytes() | |||
| return data | |||
| def load_stream(self, filename: str) -> Generator: | |||
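Note: `Path.write_bytes` / `Path.read_bytes` are the binary-mode counterparts of the `read_text` rewrites earlier in this PR. A standalone sketch (file name hypothetical):

```python
from pathlib import Path

p = Path("example.bin")                 # hypothetical scratch file
p.write_bytes(b"\x00\x01")              # replaces open(..., "wb") + write
assert p.read_bytes() == b"\x00\x01"    # replaces open(..., "rb") + read
p.unlink()                              # clean up
```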
| @@ -144,7 +144,7 @@ class Dataset(db.Model): | |||
| "top_k": 2, | |||
| "score_threshold_enabled": False, | |||
| } | |||
| return self.retrieval_model if self.retrieval_model else default_retrieval_model | |||
| return self.retrieval_model or default_retrieval_model | |||
| @property | |||
| def tags(self): | |||
| @@ -160,7 +160,7 @@ class Dataset(db.Model): | |||
| .all() | |||
| ) | |||
| return tags if tags else [] | |||
| return tags or [] | |||
| @staticmethod | |||
| def gen_collection_name_by_id(dataset_id: str) -> str: | |||
| @@ -118,7 +118,7 @@ class App(db.Model): | |||
| @property | |||
| def api_base_url(self): | |||
| return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1" | |||
| return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1" | |||
| @property | |||
| def tenant(self): | |||
| @@ -207,7 +207,7 @@ class App(db.Model): | |||
| .all() | |||
| ) | |||
| return tags if tags else [] | |||
| return tags or [] | |||
| class AppModelConfig(db.Model): | |||
| @@ -908,7 +908,7 @@ class Message(db.Model): | |||
| "id": message_file.id, | |||
| "type": message_file.type, | |||
| "url": url, | |||
| "belongs_to": message_file.belongs_to if message_file.belongs_to else "user", | |||
| "belongs_to": message_file.belongs_to or "user", | |||
| } | |||
| ) | |||
| @@ -1212,7 +1212,7 @@ class Site(db.Model): | |||
| @property | |||
| def app_base_url(self): | |||
| return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/") | |||
| return dify_config.APP_WEB_URL or request.url_root.rstrip("/") | |||
| class ApiToken(db.Model): | |||
| @@ -1488,7 +1488,7 @@ class TraceAppConfig(db.Model): | |||
| @property | |||
| def tracing_config_dict(self): | |||
| return self.tracing_config if self.tracing_config else {} | |||
| return self.tracing_config or {} | |||
| @property | |||
| def tracing_config_str(self): | |||
| @@ -15,6 +15,7 @@ select = [ | |||
| "C4", # flake8-comprehensions | |||
| "E", # pycodestyle E rules | |||
| "F", # pyflakes rules | |||
| "FURB", # refurb rules | |||
| "I", # isort rules | |||
| "N", # pep8-naming | |||
| "RUF019", # unnecessary-key-check | |||
| @@ -37,6 +38,8 @@ ignore = [ | |||
| "F405", # undefined-local-with-import-star-usage | |||
| "F821", # undefined-name | |||
| "F841", # unused-variable | |||
| "FURB113", # repeated-append | |||
| "FURB152", # math-constant | |||
| "UP007", # non-pep604-annotation | |||
| "UP032", # f-string | |||
| "B005", # strip-with-multi-characters | |||
| @@ -544,7 +544,7 @@ class RegisterService: | |||
| """Register account""" | |||
| try: | |||
| account = AccountService.create_account( | |||
| email=email, name=name, interface_language=language if language else languages[0], password=password | |||
| email=email, name=name, interface_language=language or languages[0], password=password | |||
| ) | |||
| account.status = AccountStatus.ACTIVE.value if not status else status.value | |||
| account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) | |||
| @@ -81,13 +81,11 @@ class AppDslService: | |||
| raise ValueError("Missing app in data argument") | |||
| # get app basic info | |||
| name = args.get("name") if args.get("name") else app_data.get("name") | |||
| description = args.get("description") if args.get("description") else app_data.get("description", "") | |||
| icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") | |||
| icon = args.get("icon") if args.get("icon") else app_data.get("icon") | |||
| icon_background = ( | |||
| args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") | |||
| ) | |||
| name = args.get("name") or app_data.get("name") | |||
| description = args.get("description") or app_data.get("description", "") | |||
| icon_type = args.get("icon_type") or app_data.get("icon_type") | |||
| icon = args.get("icon") or app_data.get("icon") | |||
| icon_background = args.get("icon_background") or app_data.get("icon_background") | |||
| use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) | |||
| # import dsl and create app | |||
| @@ -155,7 +155,7 @@ class DatasetService: | |||
| dataset.tenant_id = tenant_id | |||
| dataset.embedding_model_provider = embedding_model.provider if embedding_model else None | |||
| dataset.embedding_model = embedding_model.model if embedding_model else None | |||
| dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME | |||
| dataset.permission = permission or DatasetPermissionEnum.ONLY_ME | |||
| db.session.add(dataset) | |||
| db.session.commit() | |||
| return dataset | |||
| @@ -681,11 +681,7 @@ class DocumentService: | |||
| "score_threshold_enabled": False, | |||
| } | |||
| dataset.retrieval_model = ( | |||
| document_data.get("retrieval_model") | |||
| if document_data.get("retrieval_model") | |||
| else default_retrieval_model | |||
| ) | |||
| dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model | |||
| documents = [] | |||
| batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) | |||
| @@ -33,7 +33,7 @@ class HitTestingService: | |||
| # get the retrieval model; if none is set, use the default | |||
| if not retrieval_model: | |||
| retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model | |||
| retrieval_model = dataset.retrieval_model or default_retrieval_model | |||
| all_documents = RetrievalService.retrieve( | |||
| retrieval_method=retrieval_model.get("search_method", "semantic_search"), | |||
| @@ -46,9 +46,7 @@ class HitTestingService: | |||
| reranking_model=retrieval_model.get("reranking_model", None) | |||
| if retrieval_model["reranking_enable"] | |||
| else None, | |||
| reranking_mode=retrieval_model.get("reranking_mode") | |||
| if retrieval_model.get("reranking_mode") | |||
| else "reranking_model", | |||
| reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", | |||
| weights=retrieval_model.get("weights", None), | |||
| ) | |||
| @@ -1,6 +1,7 @@ | |||
| import logging | |||
| import mimetypes | |||
| import os | |||
| from pathlib import Path | |||
| from typing import Optional, cast | |||
| import requests | |||
| @@ -453,9 +454,8 @@ class ModelProviderService: | |||
| mimetype = mimetype or "application/octet-stream" | |||
| # read binary from file | |||
| with open(file_path, "rb") as f: | |||
| byte_data = f.read() | |||
| return byte_data, mimetype | |||
| byte_data = Path(file_path).read_bytes() | |||
| return byte_data, mimetype | |||
| def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None: | |||
| """ | |||