Browse Source

chore: refurbish Python code by applying refurb linter rules (#8296)

tags/0.8.1
Bowen Liang 1 year ago
parent commit 40fb4d16ef
No account linked to committer's email address
100 changed files with 206 additions and 262 deletions
  1. api/controllers/console/admin.py (+8 −16)
  2. api/controllers/console/app/audio.py (+2 −6)
  3. api/controllers/console/auth/oauth.py (+1 −1)
  4. api/controllers/console/datasets/datasets.py (+1 −6)
  5. api/controllers/console/explore/audio.py (+2 −6)
  6. api/controllers/console/workspace/tool_providers.py (+1 −1)
  7. api/controllers/service_api/app/audio.py (+2 −6)
  8. api/controllers/web/audio.py (+2 −6)
  9. api/core/agent/cot_agent_runner.py (+1 −1)
  10. api/core/agent/fc_agent_runner.py (+1 −1)
  11. api/core/app/apps/base_app_runner.py (+4 −4)
  12. api/core/extension/extensible.py (+2 −2)
  13. api/core/memory/token_buffer_memory.py (+1 −1)
  14. api/core/model_runtime/model_providers/__base/large_language_model.py (+1 −1)
  15. api/core/model_runtime/model_providers/anthropic/llm/llm.py (+1 −1)
  16. api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py (+1 −1)
  17. api/core/model_runtime/model_providers/azure_openai/llm/llm.py (+4 −6)
  18. api/core/model_runtime/model_providers/azure_openai/tts/tts.py (+1 −1)
  19. api/core/model_runtime/model_providers/bedrock/llm/llm.py (+4 −4)
  20. api/core/model_runtime/model_providers/chatglm/llm/llm.py (+2 −2)
  21. api/core/model_runtime/model_providers/localai/llm/llm.py (+3 −3)
  22. api/core/model_runtime/model_providers/minimax/llm/llm.py (+2 −2)
  23. api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py (+1 −1)
  24. api/core/model_runtime/model_providers/openai/llm/llm.py (+3 −5)
  25. api/core/model_runtime/model_providers/openai/tts/tts.py (+1 −1)
  26. api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py (+2 −2)
  27. api/core/model_runtime/model_providers/openllm/llm/llm.py (+2 −2)
  28. api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py (+1 −1)
  29. api/core/model_runtime/model_providers/replicate/llm/llm.py (+1 −1)
  30. api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py (+2 −1)
  31. api/core/model_runtime/model_providers/sagemaker/tts/tts.py (+1 −1)
  32. api/core/model_runtime/model_providers/spark/llm/llm.py (+1 −1)
  33. api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py (+2 −1)
  34. api/core/model_runtime/model_providers/tongyi/llm/llm.py (+2 −2)
  35. api/core/model_runtime/model_providers/upstage/llm/llm.py (+2 −4)
  36. api/core/model_runtime/model_providers/vertex_ai/llm/llm.py (+2 −2)
  37. api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py (+2 −1)
  38. api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py (+2 −1)
  39. api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py (+3 −5)
  40. api/core/model_runtime/model_providers/wenxin/llm/llm.py (+3 −3)
  41. api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py (+1 −1)
  42. api/core/model_runtime/model_providers/xinference/llm/llm.py (+3 −3)
  43. api/core/model_runtime/model_providers/xinference/tts/tts.py (+1 −1)
  44. api/core/model_runtime/model_providers/zhipuai/llm/llm.py (+2 −2)
  45. api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py (+3 −1)
  46. api/core/ops/langfuse_trace/langfuse_trace.py (+10 −12)
  47. api/core/ops/langsmith_trace/langsmith_trace.py (+4 −6)
  48. api/core/ops/ops_trace_manager.py (+7 −8)
  49. api/core/prompt/simple_prompt_transform.py (+3 −3)
  50. api/core/rag/datasource/keyword/keyword_base.py (+1 −1)
  51. api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py (+2 −2)
  52. api/core/rag/datasource/vdb/chroma/chroma_vector.py (+1 −1)
  53. api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py (+3 −3)
  54. api/core/rag/datasource/vdb/milvus/milvus_vector.py (+1 −1)
  55. api/core/rag/datasource/vdb/myscale/myscale_vector.py (+1 −1)
  56. api/core/rag/datasource/vdb/opensearch/opensearch_vector.py (+1 −1)
  57. api/core/rag/datasource/vdb/oracle/oraclevector.py (+2 −2)
  58. api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py (+1 −1)
  59. api/core/rag/datasource/vdb/pgvector/pgvector.py (+1 −1)
  60. api/core/rag/datasource/vdb/qdrant/qdrant_vector.py (+1 −1)
  61. api/core/rag/datasource/vdb/relyt/relyt_vector.py (+1 −1)
  62. api/core/rag/datasource/vdb/tencent/tencent_vector.py (+1 −1)
  63. api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py (+1 −1)
  64. api/core/rag/datasource/vdb/vector_base.py (+1 −1)
  65. api/core/rag/datasource/vdb/vector_factory.py (+1 −1)
  66. api/core/rag/datasource/vdb/weaviate/weaviate_vector.py (+1 −1)
  67. api/core/rag/extractor/blob/blob.py (+3 −5)
  68. api/core/rag/extractor/extract_processor.py (+3 −4)
  69. api/core/rag/extractor/helpers.py (+2 −2)
  70. api/core/rag/extractor/markdown_extractor.py (+3 −4)
  71. api/core/rag/extractor/text_extractor.py (+3 −4)
  72. api/core/rag/extractor/word_extractor.py (+1 −1)
  73. api/core/rag/retrieval/dataset_retrieval.py (+4 −6)
  74. api/core/tools/provider/api_tool_provider.py (+1 −1)
  75. api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py (+2 −1)
  76. api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py (+1 −1)
  77. api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py (+1 −1)
  78. api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py (+1 −1)
  79. api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py (+1 −1)
  80. api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py (+2 −4)
  81. api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py (+2 −4)
  82. api/core/tools/utils/web_reader_tool.py (+3 −3)
  83. api/core/tools/utils/yaml_utils.py (+1 −1)
  84. api/core/workflow/graph_engine/entities/graph.py (+2 −3)
  85. api/core/workflow/nodes/code/code_node.py (+4 −4)
  86. api/core/workflow/nodes/if_else/if_else_node.py (+1 −1)
  87. api/core/workflow/nodes/llm/llm_node.py (+1 −1)
  88. api/core/workflow/nodes/question_classifier/question_classifier_node.py (+1 −1)
  89. api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py (+1 −2)
  90. api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py (+1 −2)
  91. api/extensions/storage/local_storage.py (+3 −5)
  92. api/models/dataset.py (+2 −2)
  93. api/models/model.py (+5 −5)
  94. api/pyproject.toml (+3 −0)
  95. api/services/account_service.py (+1 −1)
  96. api/services/app_dsl_service.py (+5 −7)
  97. api/services/dataset_service.py (+2 −6)
  98. api/services/hit_testing_service.py (+2 −4)
  99. api/services/model_provider_service.py (+3 −3)
  100. api/services/recommended_app_service.py (+0 −0)

+8 −16  api/controllers/console/admin.py

@@ -60,23 +60,15 @@ class InsertExploreAppListApi(Resource):

site = app.site
if not site:
- desc = args["desc"] if args["desc"] else ""
- copy_right = args["copyright"] if args["copyright"] else ""
- privacy_policy = args["privacy_policy"] if args["privacy_policy"] else ""
- custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else ""
+ desc = args["desc"] or ""
+ copy_right = args["copyright"] or ""
+ privacy_policy = args["privacy_policy"] or ""
+ custom_disclaimer = args["custom_disclaimer"] or ""
else:
- desc = site.description if site.description else args["desc"] if args["desc"] else ""
- copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else ""
- privacy_policy = (
-     site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else ""
- )
- custom_disclaimer = (
-     site.custom_disclaimer
-     if site.custom_disclaimer
-     else args["custom_disclaimer"]
-     if args["custom_disclaimer"]
-     else ""
- )
+ desc = site.description or args["desc"] or ""
+ copy_right = site.copyright or args["copyright"] or ""
+ privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
+ custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""

recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
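
Most hunks below apply the same refurb rewrite: `x if x else y` becomes `x or y`. A minimal sketch of the equivalence with hypothetical values (not the real request args), including the falsiness caveat both spellings share:

# `a if a else b` and `a or b` evaluate `a` once and fall back to `b`
# whenever `a` is falsy (None, "", 0, empty containers).
desc = None
verbose = desc if desc else "default"   # pattern before the lint
concise = desc or "default"             # pattern after the lint
assert verbose == concise == "default"

# Both forms treat a falsy-but-present value such as "" as missing too.
assert ("" if "" else "default") == ("" or "default") == "default"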


+2 −6  api/controllers/console/app/audio.py

@@ -99,14 +99,10 @@ class ChatMessageTextApi(Resource):
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
- voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+ voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
- voice = (
-     args.get("voice")
-     if args.get("voice")
-     else app_model.app_model_config.text_to_speech_dict.get("voice")
- )
+ voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)

+1 −1  api/controllers/console/auth/oauth.py

@@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):

if not account:
# Create account
- account_name = user_info.name if user_info.name else "Dify"
+ account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
)

+1 −6  api/controllers/console/datasets/datasets.py

@@ -550,12 +550,7 @@ class DatasetApiBaseUrlApi(Resource):
@login_required
@account_initialization_required
def get(self):
- return {
-     "api_base_url": (
-         dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")
-     )
-     + "/v1"
- }
+ return {"api_base_url": (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"}


class DatasetRetrievalSettingApi(Resource):

+2 −6  api/controllers/console/explore/audio.py

@@ -86,14 +86,10 @@ class ChatTextApi(InstalledAppResource):
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
- voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+ voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
- voice = (
-     args.get("voice")
-     if args.get("voice")
-     else app_model.app_model_config.text_to_speech_dict.get("voice")
- )
+ voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)

+1 −1  api/controllers/console/workspace/tool_providers.py

@@ -327,7 +327,7 @@ class ToolApiProviderPreviousTestApi(Resource):

return ApiToolManageService.test_api_tool_preview(
current_user.current_tenant_id,
- args["provider_name"] if args["provider_name"] else "",
+ args["provider_name"] or "",
args["tool_name"],
args["credentials"],
args["parameters"],

+2 −6  api/controllers/service_api/app/audio.py

@@ -84,14 +84,10 @@ class TextApi(Resource):
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
- voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+ voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
- voice = (
-     args.get("voice")
-     if args.get("voice")
-     else app_model.app_model_config.text_to_speech_dict.get("voice")
- )
+ voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(

+2 −6  api/controllers/web/audio.py

@@ -83,14 +83,10 @@ class TextApi(WebApiResource):
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
- voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice")
+ voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
- voice = (
-     args.get("voice")
-     if args.get("voice")
-     else app_model.app_model_config.text_to_speech_dict.get("voice")
- )
+ voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None


+1 −1  api/core/agent/cot_agent_runner.py

@@ -256,7 +256,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
- usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+ usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),

+1 −1  api/core/agent/fc_agent_runner.py

@@ -298,7 +298,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
- usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+ usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),

+4 −4  api/core/app/apps/base_app_runner.py

@@ -161,7 +161,7 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
- query=query if query else "",
+ query=query or "",
files=files,
context=context,
memory=memory,
@@ -189,7 +189,7 @@
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
- query=query if query else "",
+ query=query or "",
files=files,
context=context,
memory_config=memory_config,
@@ -238,7 +238,7 @@
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
- usage=usage if usage else LLMUsage.empty_usage(),
+ usage=usage or LLMUsage.empty_usage(),
),
),
PublishFrom.APPLICATION_MANAGER,
@@ -351,7 +351,7 @@
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
- query=query if query else "",
+ query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)

+2 −2  api/core/extension/extensible.py

@@ -3,6 +3,7 @@ import importlib.util
import json
import logging
import os
+ from pathlib import Path
from typing import Any, Optional

from pydantic import BaseModel
@@ -63,8 +64,7 @@ class Extensible:

builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
- with open(builtin_file_path, encoding="utf-8") as f:
-     position = int(f.read().strip())
+ position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position

if (extension_name + ".py") not in file_names:
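
Several files in this commit (this one, tongyi/llm.py, and the RAG extractors below) replace explicit open()/read()/write() blocks with pathlib one-liners. A small sketch of the pattern using a throwaway temp directory:

import tempfile
from pathlib import Path

path = Path(tempfile.mkdtemp()) / "__builtin__"

# write_text/read_text open, transfer, and close in a single call,
# standing in for the old `with open(...) as f:` blocks.
path.write_text("3\n", encoding="utf-8")
position = int(path.read_text(encoding="utf-8").strip())
assert position == 3

# Binary counterparts, as used in tongyi/llm.py and blob.py.
path.write_bytes(b"\x00\x01")
assert path.read_bytes() == b"\x00\x01"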

+1 −1  api/core/memory/token_buffer_memory.py

@@ -39,7 +39,7 @@ class TokenBufferMemory:
)

if message_limit and message_limit > 0:
- message_limit = message_limit if message_limit <= 500 else 500
+ message_limit = min(message_limit, 500)
else:
message_limit = 500
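
Clamping with a conditional expression is another refurb target; min() states the cap directly. A sketch mirroring the logic above:

def clamp_limit(message_limit, cap=500):
    # `x if x <= cap else cap` is just min(x, cap); fall back to the cap
    # when the limit is unset or non-positive.
    if message_limit and message_limit > 0:
        return min(message_limit, cap)
    return cap

assert clamp_limit(100) == 100
assert clamp_limit(10_000) == 500
assert clamp_limit(None) == 500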


+1 −1  api/core/model_runtime/model_providers/__base/large_language_model.py

@@ -449,7 +449,7 @@ if you are not sure about the structure.
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
- usage=usage if usage else LLMUsage.empty_usage(),
+ usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,

+1 −1  api/core/model_runtime/model_providers/anthropic/llm/llm.py

@@ -409,7 +409,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
- chunk_text = chunk.delta.text if chunk.delta.text else ""
+ chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text

# transform assistant message to prompt message

+1 −1  api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py

@@ -213,7 +213,7 @@ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
model=real_model,
prompt_messages=prompt_messages,
message=prompt_message,
- usage=usage if usage else LLMUsage.empty_usage(),
+ usage=usage or LLMUsage.empty_usage(),
system_fingerprint=system_fingerprint,
),
credentials=credentials,

+4 −6  api/core/model_runtime/model_providers/azure_openai/llm/llm.py

@@ -225,7 +225,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue

# transform assistant message to prompt message
- text = delta.text if delta.text else ""
+ text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)

full_text += text
@@ -400,15 +400,13 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
continue

# transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
-     content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
- )
+ assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

- full_assistant_content += delta.delta.content if delta.delta.content else ""
+ full_assistant_content += delta.delta.content or ""

real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
- completion += delta.delta.content if delta.delta.content else ""
+ completion += delta.delta.content or ""

yield LLMResultChunk(
model=real_model,

+1 −1  api/core/model_runtime/model_providers/azure_openai/tts/tts.py

@@ -84,7 +84,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
)
for i in range(len(sentences))
]
- for index, future in enumerate(futures):
+ for future in futures:
yield from future.result().__enter__().iter_bytes(1024)

else:
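
This loop (and its twins in the openai, sagemaker, and xinference TTS models) asked enumerate() for an index it never used; refurb drops the wrapper. Sketch:

futures = ["chunk-a", "chunk-b"]  # stand-ins for concurrent.futures results

# Before: the index is unpacked and then ignored.
for index, future in enumerate(futures):
    _ = future

# After: iterate directly when only the items matter.
for future in futures:
    _ = future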

+4 −4  api/core/model_runtime/model_providers/bedrock/llm/llm.py

@@ -331,10 +331,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif "contentBlockDelta" in chunk:
delta = chunk["contentBlockDelta"]["delta"]
if "text" in delta:
- chunk_text = delta["text"] if delta["text"] else ""
+ chunk_text = delta["text"] or ""
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
- content=chunk_text if chunk_text else "",
+ content=chunk_text or "",
)
index = chunk["contentBlockDelta"]["contentBlockIndex"]
yield LLMResultChunk(
@@ -751,7 +751,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
elif model_prefix == "cohere":
output = response_body.get("generations")[0].get("text")
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
- completion_tokens = self.get_num_tokens(model, credentials, output if output else "")
+ completion_tokens = self.get_num_tokens(model, credentials, output or "")

else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@@ -828,7 +828,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
- content=content_delta if content_delta else "",
+ content=content_delta or "",
)
index += 1


+2 −2  api/core/model_runtime/model_providers/chatglm/llm/llm.py

@@ -302,11 +302,11 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
if delta.delta.function_call:
function_calls = [delta.delta.function_call]

- assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+ assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
- content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+ content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
)

if delta.finish_reason is not None:

+3 −3  api/core/model_runtime/model_providers/localai/llm/llm.py

@@ -511,7 +511,7 @@ class LocalAILanguageModel(LargeLanguageModel):
delta = chunk.choices[0]

# transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+ assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])

if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage
@@ -578,11 +578,11 @@ class LocalAILanguageModel(LargeLanguageModel):
if delta.delta.function_call:
function_calls = [delta.delta.function_call]

- assistant_message_tool_calls = self._extract_response_tool_calls(function_calls if function_calls else [])
+ assistant_message_tool_calls = self._extract_response_tool_calls(function_calls or [])

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
- content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+ content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
)

if delta.finish_reason is not None:

+2 −2  api/core/model_runtime/model_providers/minimax/llm/llm.py

@@ -211,7 +211,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage,
- finish_reason=message.stop_reason if message.stop_reason else None,
+ finish_reason=message.stop_reason or None,
),
)
elif message.function_call:
@@ -244,7 +244,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
- finish_reason=message.stop_reason if message.stop_reason else None,
+ finish_reason=message.stop_reason or None,
),
)

+1 −1  api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py

@@ -65,7 +65,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
inputs = []
used_tokens = 0

- for i, text in enumerate(texts):
+ for text in texts:
# Here token count is only an approximation based on the GPT2 tokenizer
num_tokens = self._get_num_tokens_by_gpt2(text)

+3 −5  api/core/model_runtime/model_providers/openai/llm/llm.py

@@ -508,7 +508,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
continue

# transform assistant message to prompt message
- text = delta.text if delta.text else ""
+ text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)

full_text += text
@@ -760,11 +760,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
final_tool_calls.extend(tool_calls)

# transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
-     content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
- )
+ assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

- full_assistant_content += delta.delta.content if delta.delta.content else ""
+ full_assistant_content += delta.delta.content or ""

if has_finish_reason:
final_chunk = LLMResultChunk(

+1 −1  api/core/model_runtime/model_providers/openai/tts/tts.py

@@ -88,7 +88,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
)
for i in range(len(sentences))
]
- for index, future in enumerate(futures):
+ for future in futures:
yield from future.result().__enter__().iter_bytes(1024)

else:

+2 −2  api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -179,9 +179,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
features = []

function_calling_type = credentials.get("function_calling_type", "no_call")
- if function_calling_type in ["function_call"]:
+ if function_calling_type == "function_call":
features.append(ModelFeature.TOOL_CALL)
- elif function_calling_type in ["tool_call"]:
+ elif function_calling_type == "tool_call":
features.append(ModelFeature.MULTI_TOOL_CALL)

stream_function_calling = credentials.get("stream_function_calling", "supported")
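
Membership in a single-element list is equality in disguise, so refurb rewrites `x in ["y"]` as `x == "y"` (the same change appears in extract_processor.py below). Sketch:

function_calling_type = "tool_call"

# Before: builds a throwaway list just to test one value.
assert function_calling_type in ["tool_call"]

# After: plain equality, same result, no list allocation.
assert function_calling_type == "tool_call"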

+2 −2  api/core/model_runtime/model_providers/openllm/llm/llm.py

@@ -179,7 +179,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage,
- finish_reason=message.stop_reason if message.stop_reason else None,
+ finish_reason=message.stop_reason or None,
),
)
else:
@@ -189,7 +189,7 @@
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
- finish_reason=message.stop_reason if message.stop_reason else None,
+ finish_reason=message.stop_reason or None,
),
)


+1 −1  api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py

@@ -106,7 +106,7 @@ class OpenLLMGenerate:
timeout = 120

data = {
- "stop": stop if stop else [],
+ "stop": stop or [],
"prompt": "\n".join([message.content for message in prompt_messages]),
"llm_config": default_llm_config,
}

+1 −1  api/core/model_runtime/model_providers/replicate/llm/llm.py

@@ -214,7 +214,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):

index += 1

- assistant_prompt_message = AssistantPromptMessage(content=output if output else "")
+ assistant_prompt_message = AssistantPromptMessage(content=output or "")

if index < prediction_output_length:
yield LLMResultChunk(

+2 −1  api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py

@@ -1,5 +1,6 @@
import json
import logging
+ import operator
from typing import Any, Optional

import boto3
@@ -94,7 +95,7 @@ class SageMakerRerankModel(RerankModel):
for idx in range(len(scores)):
candidate_docs.append({"content": docs[idx], "score": scores[idx]})

- sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
+ sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)

line = 3
rerank_documents = []
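
operator.itemgetter replaces key lambdas here and in the tencent recognizer below. A sketch of the sort (note that the surrounding code still discards the sorted() result, a pre-existing quirk this lint does not touch):

import operator

candidate_docs = [
    {"content": "a", "score": 0.2},
    {"content": "b", "score": 0.9},
]

# itemgetter("score") is a named, slightly faster equivalent of
# `lambda x: x["score"]`.
ranked = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)
assert [d["content"] for d in ranked] == ["b", "a"]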

+1 −1  api/core/model_runtime/model_providers/sagemaker/tts/tts.py

@@ -260,7 +260,7 @@ class SageMakerText2SpeechModel(TTSModel):
for payload in payloads
]

- for index, future in enumerate(futures):
+ for future in futures:
resp = future.result()
audio_bytes = requests.get(resp.get("s3_presign_url")).content
for i in range(0, len(audio_bytes), 1024):

+1 −1  api/core/model_runtime/model_providers/spark/llm/llm.py

@@ -220,7 +220,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
delta = content

assistant_prompt_message = AssistantPromptMessage(
- content=delta if delta else "",
+ content=delta or "",
)

prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)

+2 −1  api/core/model_runtime/model_providers/tencent/speech2text/flash_recognizer.py

@@ -1,6 +1,7 @@
import base64
import hashlib
import hmac
+ import operator
import time

import requests
@@ -127,7 +128,7 @@ class FlashRecognizer:
return s

def _build_req_with_signature(self, secret_key, params, header):
- query = sorted(params.items(), key=lambda d: d[0])
+ query = sorted(params.items(), key=operator.itemgetter(0))
signstr = self._format_sign_string(query)
signature = self._sign(signstr, secret_key)
header["Authorization"] = signature

+2 −2  api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -4,6 +4,7 @@ import tempfile
import uuid
from collections.abc import Generator
from http import HTTPStatus
+ from pathlib import Path
from typing import Optional, Union, cast

from dashscope import Generation, MultiModalConversation, get_tokenizer
@@ -454,8 +455,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):

file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{mime_type.split('/')[1]}")

- with open(file_path, "wb") as image_file:
-     image_file.write(base64.b64decode(encoded_string))
+ Path(file_path).write_bytes(base64.b64decode(encoded_string))

return f"file://{file_path}"


+2 −4  api/core/model_runtime/model_providers/upstage/llm/llm.py

@@ -368,11 +368,9 @@ class UpstageLargeLanguageModel(_CommonUpstage, LargeLanguageModel):
final_tool_calls.extend(tool_calls)

# transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(
-     content=delta.delta.content if delta.delta.content else "", tool_calls=tool_calls
- )
+ assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)

- full_assistant_content += delta.delta.content if delta.delta.content else ""
+ full_assistant_content += delta.delta.content or ""

if has_finish_reason:
final_chunk = LLMResultChunk(

+2 −2  api/core/model_runtime/model_providers/vertex_ai/llm/llm.py

@@ -231,10 +231,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
),
)
elif isinstance(chunk, ContentBlockDeltaEvent):
- chunk_text = chunk.delta.text if chunk.delta.text else ""
+ chunk_text = chunk.delta.text or ""
full_assistant_content += chunk_text
assistant_prompt_message = AssistantPromptMessage(
- content=chunk_text if chunk_text else "",
+ content=chunk_text or "",
)
index = chunk.index
yield LLMResultChunk(

+2 −1  api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py

@@ -1,5 +1,6 @@
# coding : utf-8
import datetime
+ from itertools import starmap

import pytz

@@ -48,7 +49,7 @@ class SignResult:
self.authorization = ""

def __str__(self):
- return "\n".join(["{}:{}".format(*item) for item in self.__dict__.items()])
+ return "\n".join(list(starmap("{}:{}".format, self.__dict__.items())))


class Credentials:
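
itertools.starmap unpacks each tuple into the callable, replacing the `f(*item) for item in ...` comprehension used here and in the zhipuai SDK below. Sketch with made-up fields:

from itertools import starmap

fields = {"ak": "key", "region": "cn-north-1"}

spelled_out = ["{}:{}".format(*item) for item in fields.items()]
with_starmap = list(starmap("{}:{}".format, fields.items()))
assert spelled_out == with_starmap == ["ak:key", "region:cn-north-1"]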

+2 −1  api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/util.py

@@ -1,5 +1,6 @@
import hashlib
import hmac
+ import operator
from functools import reduce
from urllib.parse import quote

@@ -40,4 +41,4 @@ class Util:
if len(hv) == 1:
hv = "0" + hv
lst.append(hv)
- return reduce(lambda x, y: x + y, lst)
+ return reduce(operator.add, lst)
+3 −5  api/core/model_runtime/model_providers/volcengine_maas/llm/llm.py

@@ -174,9 +174,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
- message=AssistantPromptMessage(
-     content=message["content"] if message["content"] else "", tool_calls=[]
- ),
+ message=AssistantPromptMessage(content=message["content"] or "", tool_calls=[]),
usage=usage,
finish_reason=choice.get("finish_reason"),
),
@@ -208,7 +206,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
- content=message["content"] if message["content"] else "",
+ content=message["content"] or "",
tool_calls=tool_calls,
),
usage=self._calc_response_usage(
@@ -284,7 +282,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
- content=message.content if message.content else "",
+ content=message.content or "",
tool_calls=tool_calls,
),
usage=self._calc_response_usage(

+3 −3  api/core/model_runtime/model_providers/wenxin/llm/llm.py

@@ -199,7 +199,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
secret_key=credentials["secret_key"],
)

- user = user if user else "ErnieBotDefault"
+ user = user or "ErnieBotDefault"

# convert prompt messages to baichuan messages
messages = [
@@ -289,7 +289,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
usage=usage,
- finish_reason=message.stop_reason if message.stop_reason else None,
+ finish_reason=message.stop_reason or None,
),
)
else:
@@ -299,7 +299,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=message.content, tool_calls=[]),
- finish_reason=message.stop_reason if message.stop_reason else None,
+ finish_reason=message.stop_reason or None,
),
)


+1 −1  api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py

@@ -85,7 +85,7 @@ class WenxinTextEmbeddingModel(TextEmbeddingModel):
api_key = credentials["api_key"]
secret_key = credentials["secret_key"]
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
- user = user if user else "ErnieBotDefault"
+ user = user or "ErnieBotDefault"

context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials)

+3 −3  api/core/model_runtime/model_providers/xinference/llm/llm.py

@@ -589,7 +589,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

# convert tool call to assistant message tool call
tool_calls = assistant_message.tool_calls
- assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls if tool_calls else [])
+ assistant_prompt_message_tool_calls = self._extract_response_tool_calls(tool_calls or [])
function_call = assistant_message.function_call
if function_call:
assistant_prompt_message_tool_calls += [self._extract_response_function_call(function_call)]
@@ -652,7 +652,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
- content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_message_tool_calls
+ content=delta.delta.content or "", tool_calls=assistant_message_tool_calls
)

if delta.finish_reason is not None:
@@ -749,7 +749,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
delta = chunk.choices[0]

# transform assistant message to prompt message
- assistant_prompt_message = AssistantPromptMessage(content=delta.text if delta.text else "", tool_calls=[])
+ assistant_prompt_message = AssistantPromptMessage(content=delta.text or "", tool_calls=[])

if delta.finish_reason is not None:
# temp_assistant_prompt_message is used to calculate usage

+1 −1  api/core/model_runtime/model_providers/xinference/tts/tts.py

@@ -215,7 +215,7 @@ class XinferenceText2SpeechModel(TTSModel):
for i in range(len(sentences))
]

- for index, future in enumerate(futures):
+ for future in futures:
response = future.result()
for i in range(0, len(response), 1024):
yield response[i : i + 1024]

+2 −2  api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -414,10 +414,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
- content=delta.delta.content if delta.delta.content else "", tool_calls=assistant_tool_calls
+ content=delta.delta.content or "", tool_calls=assistant_tool_calls
)

- full_assistant_content += delta.delta.content if delta.delta.content else ""
+ full_assistant_content += delta.delta.content or ""

if delta.finish_reason is not None and chunk.usage is not None:
completion_tokens = chunk.usage.completion_tokens

+3 −1  api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py

@@ -30,6 +30,8 @@ def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
return {key: val for key, val in merged.items() if val is not None}


+ from itertools import starmap
+
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT

ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
@@ -159,7 +161,7 @@ class HttpClient:
return [(key, str_data)]

def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
- items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
+ items = flatten(list(starmap(self._object_to_formdata, data.items())))

serialized: dict[str, object] = {}
for key, value in items:

+10 −12  api/core/ops/langfuse_trace/langfuse_trace.py

@@ -65,7 +65,7 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)

def workflow_trace(self, trace_info: WorkflowTraceInfo):
- trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
+ trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id")
if trace_info.message_id:
trace_id = trace_info.message_id
@@ -84,7 +84,7 @@
)
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
- id=(trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id),
+ id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
@@ -93,7 +93,7 @@
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
- status_message=trace_info.error if trace_info.error else "",
+ status_message=trace_info.error or "",
)
self.add_span(langfuse_span_data=workflow_span_data)
else:
@@ -143,7 +143,7 @@
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
- created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+ created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)

@@ -172,10 +172,8 @@
end_time=finished_at,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
- status_message=trace_info.error if trace_info.error else "",
- parent_observation_id=(
-     trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
- ),
+ status_message=trace_info.error or "",
+ parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
)
else:
span_data = LangfuseSpan(
@@ -188,7 +186,7 @@
end_time=finished_at,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
- status_message=trace_info.error if trace_info.error else "",
+ status_message=trace_info.error or "",
)

self.add_span(langfuse_span_data=span_data)
@@ -212,7 +210,7 @@
output=outputs,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
- status_message=trace_info.error if trace_info.error else "",
+ status_message=trace_info.error or "",
usage=generation_usage,
)

@@ -277,7 +275,7 @@
output=message_data.answer,
metadata=metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
- status_message=message_data.error if message_data.error else "",
+ status_message=message_data.error or "",
usage=generation_usage,
)

@@ -319,7 +317,7 @@
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
- status_message=message_data.error if message_data.error else "",
+ status_message=message_data.error or "",
usage=generation_usage,
)


+4 −6  api/core/ops/langsmith_trace/langsmith_trace.py

@@ -82,7 +82,7 @@ class LangSmithDataTrace(BaseTraceInstance):
langsmith_run = LangSmithRunModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
- id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
+ id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
inputs=trace_info.workflow_run_inputs,
run_type=LangSmithRunType.tool,
@@ -94,7 +94,7 @@
},
error=trace_info.error,
tags=["workflow"],
- parent_run_id=trace_info.message_id if trace_info.message_id else None,
+ parent_run_id=trace_info.message_id or None,
)

self.add_run(langsmith_run)
@@ -133,7 +133,7 @@
else:
inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
- created_at = node_execution.created_at if node_execution.created_at else datetime.now()
+ created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)

@@ -180,9 +180,7 @@
extra={
"metadata": metadata,
},
- parent_run_id=trace_info.workflow_app_log_id
- if trace_info.workflow_app_log_id
- else trace_info.workflow_run_id,
+ parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
tags=["node_execution"],
)

+7 −8  api/core/ops/ops_trace_manager.py

@@ -354,11 +354,11 @@ class TraceTask:
workflow_run_inputs = json.loads(workflow_run.inputs) if workflow_run.inputs else {}
workflow_run_outputs = json.loads(workflow_run.outputs) if workflow_run.outputs else {}
workflow_run_version = workflow_run.version
- error = workflow_run.error if workflow_run.error else ""
+ error = workflow_run.error or ""

total_tokens = workflow_run.total_tokens

- file_list = workflow_run_inputs.get("sys.file") if workflow_run_inputs.get("sys.file") else []
+ file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""

# get workflow_app_log_id
@@ -452,7 +452,7 @@
message_tokens=message_tokens,
answer_tokens=message_data.answer_tokens,
total_tokens=message_tokens + message_data.answer_tokens,
- error=message_data.error if message_data.error else "",
+ error=message_data.error or "",
inputs=inputs,
outputs=message_data.answer,
file_list=file_list,
@@ -487,7 +487,7 @@
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

moderation_trace_info = ModerationTraceInfo(
- message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+ message_id=workflow_app_log_id or message_id,
inputs=inputs,
message_data=message_data.to_dict(),
flagged=moderation_result.flagged,
@@ -527,7 +527,7 @@
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None

suggested_question_trace_info = SuggestedQuestionTraceInfo(
- message_id=workflow_app_log_id if workflow_app_log_id else message_id,
+ message_id=workflow_app_log_id or message_id,
message_data=message_data.to_dict(),
inputs=message_data.message,
outputs=message_data.answer,
@@ -569,7 +569,7 @@

dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
message_id=message_id,
- inputs=message_data.query if message_data.query else message_data.inputs,
+ inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents],
start_time=timer.get("start"),
end_time=timer.get("end"),
@@ -695,8 +695,7 @@
self.start_timer()

def add_trace_task(self, trace_task: TraceTask):
- global trace_manager_timer
- global trace_manager_queue
+ global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
trace_task.app_id = self.app_id

+3 −3  api/core/prompt/simple_prompt_transform.py

@@ -112,11 +112,11 @@ class SimplePromptTransform(PromptTransform):
for v in prompt_template_config["special_variable_keys"]:
# support #context#, #query# and #histories#
if v == "#context#":
- variables["#context#"] = context if context else ""
+ variables["#context#"] = context or ""
elif v == "#query#":
- variables["#query#"] = query if query else ""
+ variables["#query#"] = query or ""
elif v == "#histories#":
- variables["#histories#"] = histories if histories else ""
+ variables["#histories#"] = histories or ""

prompt_template = prompt_template_config["prompt_template"]
prompt = prompt_template.format(variables)

+1 −1  api/core/rag/datasource/keyword/keyword_base.py

@@ -34,7 +34,7 @@ class BaseKeyword(ABC):
raise NotImplementedError

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
- for text in texts[:]:
+ for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:
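
`texts[:]` and `texts.copy()` build the same shallow copy; refurb prefers the named method. The copy matters here because the loop removes items from the list it is scanning. A sketch under simplified assumptions (strings instead of Documents, a set instead of the vector store):

texts = ["a", "b", "a"]
existing = {"b"}  # pretend "b" is already indexed

# Iterate over a shallow copy so in-place removals don't skip elements.
for text in texts.copy():   # identical behavior to texts[:]
    if text in existing:
        texts.remove(text)

assert texts == ["a", "a"]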

+2 −2  api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py

@@ -239,7 +239,7 @@ class AnalyticdbVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
@@ -267,7 +267,7 @@
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
region_id=self.config.region_id,
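
The vector stores below all normalize their score_threshold handling to a plain dict.get with a default. A sketch of the equivalence, including the one observable difference: an explicitly passed None now passes through instead of collapsing to 0.0.

def old_threshold(**kwargs):
    return kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0

def new_threshold(**kwargs):
    return kwargs.get("score_threshold", 0.0)

assert old_threshold() == new_threshold() == 0.0
assert old_threshold(score_threshold=0.5) == new_threshold(score_threshold=0.5) == 0.5
# Edge case: an explicit None used to collapse to 0.0, now it survives.
assert old_threshold(score_threshold=None) == 0.0
assert new_threshold(score_threshold=None) is None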

+1 −1  api/core/rag/datasource/vdb/chroma/chroma_vector.py

@@ -92,7 +92,7 @@ class ChromaVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
collection = self._client.get_or_create_collection(self._collection_name)
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)

ids: list[str] = results["ids"][0]
documents: list[str] = results["documents"][0]

+3 −3  api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py

@@ -86,8 +86,8 @@ class ElasticSearchVector(BaseVector):
id=uuids[i],
document={
Field.CONTENT_KEY.value: documents[i].page_content,
- Field.VECTOR.value: embeddings[i] if embeddings[i] else None,
- Field.METADATA_KEY.value: documents[i].metadata if documents[i].metadata else {},
+ Field.VECTOR.value: embeddings[i] or None,
+ Field.METADATA_KEY.value: documents[i].metadata or {},
},
)
self._client.indices.refresh(index=self._collection_name)
@@ -131,7 +131,7 @@

docs = []
for doc, score in docs_and_scores:
- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)

+1 −1  api/core/rag/datasource/vdb/milvus/milvus_vector.py

@@ -141,7 +141,7 @@ class MilvusVector(BaseVector):
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

+1 −1  api/core/rag/datasource/vdb/myscale/myscale_vector.py

@@ -122,7 +122,7 @@ class MyScaleVector(BaseVector):

def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
- score_threshold = kwargs.get("score_threshold") or 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
if self._metric.upper() == "COSINE" and order == SortOrder.ASC and score_threshold > 0.0

+1 −1  api/core/rag/datasource/vdb/opensearch/opensearch_vector.py

@@ -170,7 +170,7 @@ class OpenSearchVector(BaseVector):
metadata = {}

metadata["score"] = hit["_score"]
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if hit["_score"] > score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

+2 −2  api/core/rag/datasource/vdb/oracle/oraclevector.py

@@ -200,7 +200,7 @@ class OracleVector(BaseVector):
[numpy.array(query_vector)],
)
docs = []
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance
@@ -212,7 +212,7 @@
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")

+1 −1  api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -198,7 +198,7 @@ class PGVectoRS(BaseVector):
metadata = record.meta
score = 1 - dis
metadata["score"] = score
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if score > score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)

+1 −1  api/core/rag/datasource/vdb/pgvector/pgvector.py

@@ -144,7 +144,7 @@ class PGVector(BaseVector):
(json.dumps(query_vector),),
)
docs = []
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
for record in cur:
metadata, text, distance = record
score = 1 - distance

+1 −1  api/core/rag/datasource/vdb/qdrant/qdrant_vector.py

@@ -339,7 +339,7 @@ class QdrantVector(BaseVector):
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(

+1 −1  api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -230,7 +230,7 @@ class RelytVector(BaseVector):
# Organize results.
docs = []
for document, score in results:
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
if 1 - score > score_threshold:
docs.append(document)
return docs

+1 −1  api/core/rag/datasource/vdb/tencent/tencent_vector.py

@@ -153,7 +153,7 @@ class TencentVector(BaseVector):
limit=kwargs.get("top_k", 4),
timeout=self._client_config.timeout,
)
- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
return self._get_search_res(res, score_threshold)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:

+1 −1  api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py

@@ -185,7 +185,7 @@ class TiDBVector(BaseVector):

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
- score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold


+1 −1  api/core/rag/datasource/vdb/vector_base.py

@@ -49,7 +49,7 @@ class BaseVector(ABC):
raise NotImplementedError

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
- for text in texts[:]:
+ for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:

+1 −1  api/core/rag/datasource/vdb/vector_factory.py

@@ -153,7 +153,7 @@ class Vector:
return CacheEmbedding(embedding_model)

def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
- for text in texts[:]:
+ for text in texts.copy():
doc_id = text.metadata["doc_id"]
exists_duplicate_node = self.text_exists(doc_id)
if exists_duplicate_node:

+1 −1  api/core/rag/datasource/vdb/weaviate/weaviate_vector.py

@@ -205,7 +205,7 @@ class WeaviateVector(BaseVector):

docs = []
for doc, score in docs_and_scores:
- score_threshold = kwargs.get("score_threshold", 0.0) if kwargs.get("score_threshold", 0.0) else 0.0
+ score_threshold = kwargs.get("score_threshold", 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score

+3 −5  api/core/rag/extractor/blob/blob.py

@@ -12,7 +12,7 @@ import mimetypes
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable, Mapping
from io import BufferedReader, BytesIO
- from pathlib import PurePath
+ from pathlib import Path, PurePath
from typing import Any, Optional, Union

from pydantic import BaseModel, ConfigDict, model_validator
@@ -56,8 +56,7 @@ class Blob(BaseModel):
def as_string(self) -> str:
"""Read data as a string."""
if self.data is None and self.path:
- with open(str(self.path), encoding=self.encoding) as f:
-     return f.read()
+ return Path(str(self.path)).read_text(encoding=self.encoding)
elif isinstance(self.data, bytes):
return self.data.decode(self.encoding)
elif isinstance(self.data, str):
@@ -72,8 +71,7 @@
elif isinstance(self.data, str):
return self.data.encode(self.encoding)
elif self.data is None and self.path:
- with open(str(self.path), "rb") as f:
-     return f.read()
+ return Path(str(self.path)).read_bytes()
else:
raise ValueError(f"Unable to get bytes for blob {self}")


+3 −4  api/core/rag/extractor/extract_processor.py

@@ -68,8 +68,7 @@ class ExtractProcessor:
suffix = "." + re.search(r"\.(\w+)$", filename).group(1)

file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
- with open(file_path, "wb") as file:
-     file.write(response.content)
+ Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type="upload_file", document_model="text_model")
if return_text:
delimiter = "\n"
@@ -111,7 +110,7 @@
)
elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path)
- elif file_extension in [".docx"]:
+ elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
@@ -143,7 +142,7 @@
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in [".htm", ".html"]:
extractor = HtmlExtractor(file_path)
- elif file_extension in [".docx"]:
+ elif file_extension == ".docx":
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)

+2 −2  api/core/rag/extractor/helpers.py

@@ -1,6 +1,7 @@
"""Document loader helpers."""

import concurrent.futures
+ from pathlib import Path
from typing import NamedTuple, Optional, cast


@@ -28,8 +29,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncodin
import chardet

def read_and_detect(file_path: str) -> list[dict]:
- with open(file_path, "rb") as f:
-     rawdata = f.read()
+ rawdata = Path(file_path).read_bytes()
return cast(list[dict], chardet.detect_all(rawdata))

with concurrent.futures.ThreadPoolExecutor() as executor:

+ 3
- 4
api/core/rag/extractor/markdown_extractor.py Dosyayı Görüntüle

@@ -1,6 +1,7 @@
"""Abstract interface for document loader implementations."""

import re
from pathlib import Path
from typing import Optional, cast

from core.rag.extractor.extractor_base import BaseExtractor
@@ -102,15 +103,13 @@ class MarkdownExtractor(BaseExtractor):
"""Parse file into tuples."""
content = ""
try:
with open(filepath, encoding=self._encoding) as f:
content = f.read()
content = Path(filepath).read_text(encoding=self._encoding)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(filepath)
for encoding in detected_encodings:
try:
with open(filepath, encoding=encoding.encoding) as f:
content = f.read()
content = Path(filepath).read_text(encoding=encoding.encoding)
break
except UnicodeDecodeError:
continue
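The surrounding logic is unchanged: try the configured encoding, then fall back to detected candidates until one decodes. A standalone sketch of the pattern, with a hypothetical candidate tuple standing in for the project's `detect_file_encodings` helper:

from pathlib import Path

def read_with_fallback(filepath: str, encoding: str = "utf-8") -> str:
    try:
        return Path(filepath).read_text(encoding=encoding)
    except UnicodeDecodeError:
        # hypothetical stand-in for detect_file_encodings(filepath)
        for candidate in ("utf-8-sig", "latin-1"):
            try:
                return Path(filepath).read_text(encoding=candidate)
            except UnicodeDecodeError:
                continue
        raise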

+ 3
- 4
api/core/rag/extractor/text_extractor.py View File

@@ -1,5 +1,6 @@
"""Abstract interface for document loader implementations."""

from pathlib import Path
from typing import Optional

from core.rag.extractor.extractor_base import BaseExtractor
@@ -25,15 +26,13 @@ class TextExtractor(BaseExtractor):
"""Load from file path."""
text = ""
try:
with open(self._file_path, encoding=self._encoding) as f:
text = f.read()
text = Path(self._file_path).read_text(encoding=self._encoding)
except UnicodeDecodeError as e:
if self._autodetect_encoding:
detected_encodings = detect_file_encodings(self._file_path)
for encoding in detected_encodings:
try:
with open(self._file_path, encoding=encoding.encoding) as f:
text = f.read()
text = Path(self._file_path).read_text(encoding=encoding.encoding)
break
except UnicodeDecodeError:
continue

+ 1
- 1
api/core/rag/extractor/word_extractor.py View File

@@ -153,7 +153,7 @@ class WordExtractor(BaseExtractor):
if col_index >= total_cols:
break
cell_content = self._parse_cell(cell, image_map).strip()
cell_colspan = cell.grid_span if cell.grid_span else 1
cell_colspan = cell.grid_span or 1
for i in range(cell_colspan):
if col_index + i < total_cols:
row_cells[col_index + i] = cell_content if i == 0 else ""

+ 4
- 6
api/core/rag/retrieval/dataset_retrieval.py View File

@@ -256,7 +256,7 @@ class DatasetRetrieval:
# get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model or default_retrieval_model

# get top k
top_k = retrieval_model_config["top_k"]
@@ -410,7 +410,7 @@ class DatasetRetrieval:
return []

# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model

if dataset.indexing_technique == "economy":
# use keyword table query
@@ -433,9 +433,7 @@ class DatasetRetrieval:
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)

@@ -486,7 +484,7 @@ class DatasetRetrieval:
}

for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model or default_retrieval_model

# get top k
top_k = retrieval_model_config["top_k"]
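The recurring rewrite in this file, `value if value else default` → `value or default`, is behavior-preserving because both expressions select the default exactly when the value is falsy, and `or` evaluates the value only once. Sketch:

retrieval_model = None

# before
config = retrieval_model if retrieval_model else {"top_k": 2}
# after: identical truthiness semantics
config = retrieval_model or {"top_k": 2}
assert config == {"top_k": 2}

The usual `or` caveat applies: an intentionally falsy value (0, "", {}) is also replaced by the default, so the rewrite is only safe where that was already the old behavior, as it is in these hunks.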

+ 1
- 1
api/core/tools/provider/api_tool_provider.py View File

@@ -106,7 +106,7 @@ class ApiToolProviderController(ToolProviderController):
"human": {"en_US": tool_bundle.summary or "", "zh_Hans": tool_bundle.summary or ""},
"llm": tool_bundle.summary or "",
},
"parameters": tool_bundle.parameters if tool_bundle.parameters else [],
"parameters": tool_bundle.parameters or [],
}
)


+ 2
- 1
api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py View File

@@ -1,4 +1,5 @@
import json
import operator
from typing import Any, Union

import boto3
@@ -71,7 +72,7 @@ class SageMakerReRankTool(BuiltinTool):
candidate_docs[idx]["score"] = scores[idx]

sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x["score"], reverse=True)
sorted_candidate_docs = sorted(candidate_docs, key=operator.itemgetter("score"), reverse=True)

return [self.create_json_message(res) for res in sorted_candidate_docs[: self.topk]]
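`operator.itemgetter("score")` is a drop-in replacement for `lambda x: x["score"]` as a sort key; it produces the same ordering without invoking a Python-level lambda per element. Sketch:

import operator

docs = [{"score": 0.2}, {"score": 0.9}, {"score": 0.5}]
by_getter = sorted(docs, key=operator.itemgetter("score"), reverse=True)
by_lambda = sorted(docs, key=lambda x: x["score"], reverse=True)
assert by_getter == by_lambda  # same ordering, less per-call overhead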

+ 1
- 1
api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py View File

@@ -115,7 +115,7 @@ class GetWorksheetFieldsTool(BuiltinTool):
fields.append(field)
fields_list.append(
f"|{field['id']}|{field['name']}|{field['type']}|{field['typeId']}|{field['description']}"
f"|{field['options'] if field['options'] else ''}|"
f"|{field['options'] or ''}|"
)

fields.append(

+ 1
- 1
api/core/tools/provider/builtin/hap/tools/get_worksheet_pivot_data.py View File

@@ -130,7 +130,7 @@ class GetWorksheetPivotDataTool(BuiltinTool):
# ]
rows = []
for row in data["data"]:
row_data = row["rows"] if row["rows"] else {}
row_data = row["rows"] or {}
row_data.update(row["columns"])
row_data.update(row["values"])
rows.append(row_data)

+ 1
- 1
api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py View File

@@ -113,7 +113,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"."
if result["total"] > 0:
result_text += (
f" The following are {result['total'] if result['total'] < limit else limit}"
f" The following are {min(limit, result['total'])}"
f" pieces of data presented in a table format:\n\n{table_header}"
)
for row in rows:
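`min(limit, result['total'])` states the clamping intent that the conditional expression only implied. Sketch:

limit, total = 50, 120
assert (total if total < limit else limit) == min(limit, total) == 50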

+ 1
- 1
api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py View File

@@ -37,7 +37,7 @@ class SearchAPI:
return {
"engine": "youtube_transcripts",
"video_id": video_id,
"lang": language if language else "en",
"lang": language or "en",
**{key: value for key, value in kwargs.items() if value not in [None, ""]},
}


+ 2
- 4
api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py View File

@@ -160,7 +160,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callback.on_query(query, dataset.id)

# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model

if dataset.indexing_technique == "economy":
# use keyword table query
@@ -183,9 +183,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)


+ 2
- 4
api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py View File

@@ -55,7 +55,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
hit_callback.on_query(query, dataset.id)

# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
@@ -76,9 +76,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
else:

+ 3
- 3
api/core/tools/utils/web_reader_tool.py View File

@@ -8,6 +8,7 @@ import subprocess
import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from urllib.parse import unquote

import chardet
@@ -98,7 +99,7 @@ def get_url(url: str, user_agent: str = None) -> str:
authors=a["byline"],
publish_date=a["date"],
top_image="",
text=a["plain_text"] if a["plain_text"] else "",
text=a["plain_text"] or "",
)

return res
@@ -117,8 +118,7 @@ def extract_using_readabilipy(html):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])

# Read output of call to Readability.parse() from JSON file and return as Python dictionary
with open(article_json_path, encoding="utf-8") as json_file:
input_json = json.loads(json_file.read())
input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8"))

# Deleting files after processing
os.unlink(article_json_path)

+ 1
- 1
api/core/tools/utils/yaml_utils.py View File

@@ -21,7 +21,7 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content if yaml_content else default_value
return yaml_content or default_value
except Exception as e:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
except Exception as e:

+ 2
- 3
api/core/workflow/graph_engine/entities/graph.py View File

@@ -268,7 +268,7 @@ class Graph(BaseModel):
f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph."
)

new_route = route[:]
new_route = route.copy()
new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node(
route=new_route,
@@ -679,8 +679,7 @@ class Graph(BaseModel):
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
all_routes_node_ids.add(node_id)
all_routes_node_ids.update(node_ids)

if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
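Two idioms in this file: `list.copy()` replaces the slice spelling `route[:]` (same shallow copy, clearer intent), and `set.update(iterable)` absorbs a whole iterable in one call instead of adding items in a loop. Sketch:

route = ["a", "b"]
new_route = route.copy()  # equivalent to route[:]
assert new_route == route and new_route is not route

all_ids = {"a"}
all_ids.update(["b", "c"])  # one call instead of a per-item add loop
assert all_ids == {"a", "b", "c"}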

+ 4
- 4
api/core/workflow/nodes/code/code_node.py View File

@@ -74,7 +74,7 @@ class CodeNode(BaseNode):
:return:
"""
if not isinstance(value, str):
if isinstance(value, type(None)):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a string")
@@ -95,7 +95,7 @@ class CodeNode(BaseNode):
:return:
"""
if not isinstance(value, int | float):
if isinstance(value, type(None)):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a number")
@@ -182,7 +182,7 @@ class CodeNode(BaseNode):
f"Output {prefix}.{output_name} is not a valid array."
f" make sure all elements are of the same type."
)
elif isinstance(output_value, type(None)):
elif output_value is None:
pass
else:
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
@@ -284,7 +284,7 @@ class CodeNode(BaseNode):

for i, value in enumerate(result[output_name]):
if not isinstance(value, dict):
if isinstance(value, type(None)):
if value is None:
pass
else:
raise ValueError(
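`value is None` is the canonical None check; `isinstance(value, type(None))` spells the same test indirectly, since `None` is the only instance of `NoneType`. Sketch:

value = None
assert isinstance(value, type(None)) is True  # before: indirect type check
assert (value is None) is True                # after: identity against the singleton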

+ 1
- 1
api/core/workflow/nodes/if_else/if_else_node.py View File

@@ -79,7 +79,7 @@ class IfElseNode(BaseNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_datas,
edge_source_handle=selected_case_id if selected_case_id else "false", # Use case ID or 'default'
edge_source_handle=selected_case_id or "false", # Use case ID or 'default'
outputs=outputs,
)


+ 1
- 1
api/core/workflow/nodes/llm/llm_node.py View File

@@ -580,7 +580,7 @@ class LLMNode(BaseNode):
prompt_messages = prompt_transform.get_prompt(
prompt_template=node_data.prompt_template,
inputs=inputs,
query=query if query else "",
query=query or "",
files=files,
context=context,
memory_config=node_data.memory,

+ 1
- 1
api/core/workflow/nodes/question_classifier/question_classifier_node.py View File

@@ -250,7 +250,7 @@ class QuestionClassifierNode(LLMNode):
for class_ in classes:
category = {"category_id": class_.id, "category_name": class_.name}
categories.append(category)
instruction = node_data.instruction if node_data.instruction else ""
instruction = node_data.instruction or ""
input_text = query
memory_str = ""
if memory:

+ 1
- 2
api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py View File

@@ -18,8 +18,7 @@ def handle(sender, **kwargs):
added_dataset_ids = dataset_ids
else:
old_dataset_ids = set()
for app_dataset_join in app_dataset_joins:
old_dataset_ids.add(app_dataset_join.dataset_id)
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids

+ 1
- 2
api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py View File

@@ -22,8 +22,7 @@ def handle(sender, **kwargs):
added_dataset_ids = dataset_ids
else:
old_dataset_ids = set()
for app_dataset_join in app_dataset_joins:
old_dataset_ids.add(app_dataset_join.dataset_id)
old_dataset_ids.update(app_dataset_join.dataset_id for app_dataset_join in app_dataset_joins)

added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids

+ 3
- 5
api/extensions/storage/local_storage.py View File

@@ -1,6 +1,7 @@
import os
import shutil
from collections.abc import Generator
from pathlib import Path

from flask import Flask

@@ -26,8 +27,7 @@ class LocalStorage(BaseStorage):
folder = os.path.dirname(filename)
os.makedirs(folder, exist_ok=True)

with open(os.path.join(os.getcwd(), filename), "wb") as f:
f.write(data)
Path(os.path.join(os.getcwd(), filename)).write_bytes(data)

def load_once(self, filename: str) -> bytes:
if not self.folder or self.folder.endswith("/"):
@@ -38,9 +38,7 @@ class LocalStorage(BaseStorage):
if not os.path.exists(filename):
raise FileNotFoundError("File not found")

with open(filename, "rb") as f:
data = f.read()

data = Path(filename).read_bytes()
return data

def load_stream(self, filename: str) -> Generator:

+ 2
- 2
api/models/dataset.py View File

@@ -144,7 +144,7 @@ class Dataset(db.Model):
"top_k": 2,
"score_threshold_enabled": False,
}
return self.retrieval_model if self.retrieval_model else default_retrieval_model
return self.retrieval_model or default_retrieval_model

@property
def tags(self):
@@ -160,7 +160,7 @@ class Dataset(db.Model):
.all()
)

return tags if tags else []
return tags or []

@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:

+ 5
- 5
api/models/model.py View File

@@ -118,7 +118,7 @@ class App(db.Model):

@property
def api_base_url(self):
return (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/")) + "/v1"
return (dify_config.SERVICE_API_URL or request.host_url.rstrip("/")) + "/v1"

@property
def tenant(self):
@@ -207,7 +207,7 @@ class App(db.Model):
.all()
)

return tags if tags else []
return tags or []


class AppModelConfig(db.Model):
@@ -908,7 +908,7 @@ class Message(db.Model):
"id": message_file.id,
"type": message_file.type,
"url": url,
"belongs_to": message_file.belongs_to if message_file.belongs_to else "user",
"belongs_to": message_file.belongs_to or "user",
}
)

@@ -1212,7 +1212,7 @@ class Site(db.Model):

@property
def app_base_url(self):
return dify_config.APP_WEB_URL if dify_config.APP_WEB_URL else request.url_root.rstrip("/")
return dify_config.APP_WEB_URL or request.url_root.rstrip("/")


class ApiToken(db.Model):
@@ -1488,7 +1488,7 @@ class TraceAppConfig(db.Model):

@property
def tracing_config_dict(self):
return self.tracing_config if self.tracing_config else {}
return self.tracing_config or {}

@property
def tracing_config_str(self):

+ 3
- 0
api/pyproject.toml View File

@@ -15,6 +15,7 @@ select = [
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"RUF019", # unnecessary-key-check
@@ -37,6 +38,8 @@ ignore = [
"F405", # undefined-local-with-import-star-usage
"F821", # undefined-name
"F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"B005", # strip-with-multi-characters

+ 1
- 1
api/services/account_service.py View File

@@ -544,7 +544,7 @@ class RegisterService:
"""Register account"""
try:
account = AccountService.create_account(
email=email, name=name, interface_language=language if language else languages[0], password=password
email=email, name=name, interface_language=language or languages[0], password=password
)
account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)

+ 5
- 7
api/services/app_dsl_service.py View File

@@ -81,13 +81,11 @@ class AppDslService:
raise ValueError("Missing app in data argument")

# get app basic info
name = args.get("name") if args.get("name") else app_data.get("name")
description = args.get("description") if args.get("description") else app_data.get("description", "")
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
icon = args.get("icon") if args.get("icon") else app_data.get("icon")
icon_background = (
args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
)
name = args.get("name") or app_data.get("name")
description = args.get("description") or app_data.get("description", "")
icon_type = args.get("icon_type") or app_data.get("icon_type")
icon = args.get("icon") or app_data.get("icon")
icon_background = args.get("icon_background") or app_data.get("icon_background")
use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False)

# import dsl and create app

+ 2
- 6
api/services/dataset_service.py View File

@@ -155,7 +155,7 @@ class DatasetService:
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
dataset.permission = permission if permission else DatasetPermissionEnum.ONLY_ME
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
db.session.add(dataset)
db.session.commit()
return dataset
@@ -681,11 +681,7 @@ class DocumentService:
"score_threshold_enabled": False,
}

dataset.retrieval_model = (
document_data.get("retrieval_model")
if document_data.get("retrieval_model")
else default_retrieval_model
)
dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model

documents = []
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))

+ 2
- 4
api/services/hit_testing_service.py View File

@@ -33,7 +33,7 @@ class HitTestingService:

# get retrieval model , if the model is not setting , using default
if not retrieval_model:
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
retrieval_model = dataset.retrieval_model or default_retrieval_model

all_documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
@@ -46,9 +46,7 @@ class HitTestingService:
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)


+ 3
- 3
api/services/model_provider_service.py View File

@@ -1,6 +1,7 @@
import logging
import mimetypes
import os
from pathlib import Path
from typing import Optional, cast

import requests
@@ -453,9 +454,8 @@ class ModelProviderService:
mimetype = mimetype or "application/octet-stream"

# read binary from file
with open(file_path, "rb") as f:
byte_data = f.read()
return byte_data, mimetype
byte_data = Path(file_path).read_bytes()
return byte_data, mimetype

def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
"""

+ 0
- 0
api/services/recommended_app_service.py View File


Some files were not shown because too many files changed in this diff
