Browse Source

fix: miss usage of os.path.join for URL assembly and add tests on yarl (#4224)

tags/0.6.8
Bowen Liang 1 year ago
parent
commit
228de1f12a
No account linked to committer's email address

+ 2
- 2
api/core/model_runtime/model_providers/chatglm/llm/llm.py View File

@@ -1,6 +1,5 @@
import logging
from collections.abc import Generator
from os.path import join
from typing import Optional, cast

from httpx import Timeout
@@ -19,6 +18,7 @@ from openai import (
)
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
from yarl import URL

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
@@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": join(credentials['api_base'], 'v1')
"base_url": str(URL(credentials['api_base']) / 'v1')
}

return client_kwargs

+ 2
- 2
api/core/tools/provider/builtin/dalle/tools/dalle2.py View File

@@ -1,8 +1,8 @@
from base64 import b64decode
from os.path import join
from typing import Any, Union

from openai import OpenAI
from yarl import URL

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool):
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = join(openai_base_url, 'v1')
openai_base_url = str(URL(openai_base_url) / 'v1')

client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],

+ 2
- 2
api/core/tools/provider/builtin/dalle/tools/dalle3.py View File

@@ -1,8 +1,8 @@
from base64 import b64decode
from os.path import join
from typing import Any, Union

from openai import OpenAI
from yarl import URL

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
@@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool):
if not openai_base_url:
openai_base_url = None
else:
openai_base_url = join(openai_base_url, 'v1')
openai_base_url = str(URL(openai_base_url) / 'v1')

client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'],

+ 23
- 0
api/tests/unit_tests/libs/test_yarl.py View File

@@ -0,0 +1,23 @@
import pytest
from yarl import URL


def test_yarl_urls():
expected_1 = 'https://dify.ai/api'
assert str(URL('https://dify.ai') / 'api') == expected_1
assert str(URL('https://dify.ai/') / 'api') == expected_1

expected_2 = 'http://dify.ai:12345/api'
assert str(URL('http://dify.ai:12345') / 'api') == expected_2
assert str(URL('http://dify.ai:12345/') / 'api') == expected_2

expected_3 = 'https://dify.ai/api/v1'
assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
assert str(URL('https://dify.ai') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/api') / 'v1') == expected_3
assert str(URL('https://dify.ai/api/') / 'v1') == expected_3

with pytest.raises(ValueError) as e1:
str(URL('https://dify.ai') / '/api')
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"

Loading…
Cancel
Save