@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os

 from werkzeug.exceptions import Unauthorized
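The `# -*- coding:utf-8 -*-` deletions in this hunk and the many identical hunks below are safe because Python 3 reads source files as UTF-8 by default (PEP 3120); the declaration only mattered on Python 2. A minimal illustration (hypothetical file, not from this patch):

    # example.py - no coding declaration needed on Python 3
    greeting = "héllo, wörld"  # non-ASCII literals are fine; source is UTF-8 by default
    print(greeting)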
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os

 import dotenv
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
 from datetime import datetime
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 import flask_login
 from flask import Response, stream_with_context
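This import rewrite is the second pattern repeated throughout the patch: PEP 585 (Python 3.9) deprecated the typing aliases for builtin and collections.abc types, so `Generator` now comes from `collections.abc` and `List`/`Dict`/`Tuple` become the lowercase builtins, while names with no builtin equivalent (`Union`, `Optional`, `cast`) stay in `typing`. An illustrative sketch of both styles (the function names are made up, not from this codebase):

    # before: deprecated typing aliases
    # from typing import Dict, Generator, List

    # after: builtin generics plus collections.abc
    from collections.abc import Generator
    from typing import Optional

    def stream_lines(text: str) -> Generator[str, None, None]:
        yield from text.splitlines()

    def count_words(lines: list[str], stop: Optional[list[str]] = None) -> dict[str, int]:
        counts: dict[str, int] = {}
        for line in lines:
            for word in line.split():
                if stop and word in stop:
                    continue  # skip hypothetical stop words
                counts[word] = counts.get(word, 0) + 1
        return counts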
@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
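The `yield from` change is the third recurring pattern: the two-line passthrough loop inside `generate()` collapses into a single delegation (PEP 380), which is equivalent for plain iteration and additionally forwards `send()` and `throw()` to the sub-generator. A self-contained sketch of the equivalence, using made-up chunk data:

    from collections.abc import Generator

    def chunks() -> Generator[str, None, None]:
        yield "data: 1\n"
        yield "data: 2\n"

    def proxy_loop() -> Generator[str, None, None]:
        for chunk in chunks():  # old form
            yield chunk

    def proxy_delegate() -> Generator[str, None, None]:
        yield from chunks()  # new form

    assert list(proxy_loop()) == list(proxy_delegate())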
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import request
 from flask_login import current_user
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from decimal import Decimal
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_login
 from flask import current_app, request
 from flask_restful import Resource, reqparse
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_restful
 from flask import current_app, request
 from flask_login import current_user
@@ -1,6 +1,4 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
-from typing import List

 from flask import request
 from flask_login import current_user
@@ -71,7 +69,7 @@ class DocumentResource(Resource):
         return document

-    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+    def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid
 from datetime import datetime
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request
@@ -1,8 +1,8 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
+from collections.abc import Generator
 from datetime import datetime
-from typing import Generator, Union
+from typing import Union

 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime

 from flask_login import current_user
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json

 from flask import current_app
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with
 from sqlalchemy import and_
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps

 from flask import current_app, request
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime

 import pytz
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, fields, marshal_with, reqparse
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 from functools import wraps
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json

 from flask import current_app
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_restful import reqparse
@@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import request
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_restful import fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from functools import wraps
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json

 from flask import current_app
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging

 from flask import request
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_restful import reqparse
@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union

 from flask import Response, stream_with_context
 from flask_restful import fields, marshal_with, reqparse
@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response

         return Response(stream_with_context(generate()), status=200,
                         mimetype='text/event-stream')
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid

 from flask import request
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import current_app
 from flask_restful import fields, marshal_with
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps

 from flask import request
| @@ -1,5 +1,5 @@ | |||
| import logging | |||
| from typing import List, Optional | |||
| from typing import Optional | |||
| from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
     def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                          prompt_messages: list[PromptMessage], model_parameters: dict,
-                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                          stream: bool = True, user: Optional[str] = None) -> None:
         """
         Before invoke callback
@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
     def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                      prompt_messages: list[PromptMessage], model_parameters: dict,
-                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                      stream: bool = True, user: Optional[str] = None):
         """
         On new chunk callback
@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
     def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         After invoke callback
@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
     def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         Invoke error callback
| @@ -1,4 +1,4 @@ | |||
| from typing import List, cast | |||
| from typing import cast | |||
| from core.entities.application_entities import ModelConfigEntity | |||
| from core.model_runtime.entities.message_entities import PromptMessage | |||
| @@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large | |||
| class CalcTokenMixin: | |||
| def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int: | |||
| def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int: | |||
| """ | |||
| Got the rest tokens available for the model after excluding messages tokens and completion max tokens | |||
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union

 from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
 from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     def real_plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         model_config: ModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union

 from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
 from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
         model_config: ModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
@@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         try:
@@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")

-    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
+    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
         rest_tokens = self.get_message_rest_tokens(
             self.model_config,
@@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
         return new_messages

     def predict_new_summary(
-        self, messages: List[BaseMessage], existing_summary: str
+        self, messages: list[BaseMessage], existing_summary: str
     ) -> str:
         new_lines = get_buffer_string(
             messages,
@@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
         chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)

-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
@@ -1,5 +1,6 @@
 import re
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast

 from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         tool_strings = []
         for tool in tools:
@@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         tools: Sequence[BaseTool],
         prefix: str = PREFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
@@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)

     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
@@ -1,5 +1,6 @@
 import re
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast

 from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")

-    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
+    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
         if len(intermediate_steps) >= 2 and self.summary_model_config:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
@@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

     def predict_new_summary(
-        self, messages: List[BaseMessage], existing_summary: str
+        self, messages: list[BaseMessage], existing_summary: str
     ) -> str:
         new_lines = get_buffer_string(
             messages,
@@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         tool_strings = []
         for tool in tools:
@@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         tools: Sequence[BaseTool],
         prefix: str = PREFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
@@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)

     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
         agent_llm_callback: Optional[AgentLLMCallback] = None,
         **kwargs: Any,
     ) -> Agent:
| @@ -1,5 +1,6 @@ | |||
| import time | |||
| from typing import Generator, List, Optional, Tuple, Union, cast | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union, cast | |||
| from core.application_queue_manager import ApplicationQueueManager, PublishFrom | |||
| from core.entities.application_entities import ( | |||
| @@ -84,7 +85,7 @@ class AppRunner: | |||
| return rest_tokens | |||
| def recale_llm_max_tokens(self, model_config: ModelConfigEntity, | |||
| prompt_messages: List[PromptMessage]): | |||
| prompt_messages: list[PromptMessage]): | |||
| # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit | |||
| model_type_instance = model_config.provider_model_bundle.model_type_instance | |||
| model_type_instance = cast(LargeLanguageModel, model_type_instance) | |||
@@ -126,7 +127,7 @@ class AppRunner:
                                  query: Optional[str] = None,
                                  context: Optional[str] = None,
                                  memory: Optional[TokenBufferMemory] = None) \
-            -> Tuple[List[PromptMessage], Optional[List[str]]]:
+            -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
         Organize prompt messages
         :param context:
@@ -295,7 +296,7 @@ class AppRunner:
                               tenant_id: str,
                               app_orchestration_config_entity: AppOrchestrationConfigEntity,
                               inputs: dict,
-                              query: str) -> Tuple[bool, dict, str]:
+                              query: str) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
@@ -1,7 +1,8 @@
 import json
 import logging
 import time
-from typing import Generator, Optional, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast

 from pydantic import BaseModel
@@ -118,7 +119,7 @@ class GenerateTaskPipeline:
                     }

                     self._task_state.llm_result.message.content = annotation.content
-            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:
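Since Python 3.10 (PEP 604), `isinstance` accepts an `X | Y` union as its second argument, so the tuple form rewritten in this hunk (and the two similar hunks below) is a pure style change; both spellings behave identically at runtime. A standalone sketch with stand-in event classes:

    class QueueStopEvent: ...
    class QueueMessageEndEvent: ...

    event = QueueStopEvent()
    assert isinstance(event, (QueueStopEvent, QueueMessageEndEvent))  # tuple form
    assert isinstance(event, QueueStopEvent | QueueMessageEndEvent)   # PEP 604 union form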
@@ -201,7 +202,7 @@ class GenerateTaskPipeline:
                 data = self._error_to_stream_response_data(self._handle_error(event))
                 yield self._yield_response(data)
                 break
-            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:
@@ -353,7 +354,7 @@ class GenerateTaskPipeline:
                 yield self._yield_response(response)

-            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
+            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
                 chunk = event.chunk
                 delta_text = chunk.delta.message.content
                 if delta_text is None:
| @@ -1,7 +1,7 @@ | |||
| import logging | |||
| import threading | |||
| import time | |||
| from typing import Any, Dict, Optional | |||
| from typing import Any, Optional | |||
| from flask import Flask, current_app | |||
| from pydantic import BaseModel | |||
| @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) | |||
| class ModerationRule(BaseModel): | |||
| type: str | |||
| config: Dict[str, Any] | |||
| config: dict[str, Any] | |||
| class OutputModerationHandler(BaseModel): | |||
@@ -2,7 +2,8 @@ import json
 import logging
 import threading
 import uuid
-from typing import Any, Generator, Optional, Tuple, Union, cast
+from collections.abc import Generator
+from typing import Any, Optional, Union, cast

 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -585,7 +586,7 @@ class ApplicationManager:
         return AppOrchestrationConfigEntity(**properties)

     def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-            -> Tuple[Conversation, Message]:
+            -> tuple[Conversation, Message]:
         """
         Initialize generate records
         :param application_generate_entity: application generate entity
@@ -1,7 +1,8 @@
 import queue
 import time
+from collections.abc import Generator
 from enum import Enum
-from typing import Any, Generator
+from typing import Any

 from sqlalchemy.orm import DeclarativeMeta
@@ -1,7 +1,7 @@
 import json
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Optional, Union, cast

 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._message_agent_thought = None

     @property
-    def agent_loops(self) -> List[AgentLoop]:
+    def agent_loops(self) -> list[AgentLoop]:
         return self._agent_loops

     def clear_agent_loops(self) -> None:
@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         **kwargs: Any
     ) -> Any:
         pass

     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         pass
@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         **kwargs: Any,
     ) -> None:
| @@ -1,5 +1,5 @@ | |||
| import os | |||
| from typing import Any, Dict, Optional, Union | |||
| from typing import Any, Optional, Union | |||
| from langchain.callbacks.base import BaseCallbackHandler | |||
| from langchain.input import print_text | |||
| @@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): | |||
| def on_tool_start( | |||
| self, | |||
| tool_name: str, | |||
| tool_inputs: Dict[str, Any], | |||
| tool_inputs: dict[str, Any], | |||
| ) -> None: | |||
| """Do nothing.""" | |||
| print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) | |||
| @@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): | |||
| def on_tool_end( | |||
| self, | |||
| tool_name: str, | |||
| tool_inputs: Dict[str, Any], | |||
| tool_inputs: dict[str, Any], | |||
| tool_outputs: str, | |||
| ) -> None: | |||
| """If not the final action, print out observation.""" | |||
@@ -1,4 +1,3 @@
-from typing import List

 from langchain.schema import Document
@@ -40,7 +39,7 @@ class DatasetIndexToolCallbackHandler:
         db.session.add(dataset_query)
         db.session.commit()

-    def on_tool_end(self, documents: List[Document]) -> None:
+    def on_tool_end(self, documents: list[Document]) -> None:
         """Handle tool end."""
         for document in documents:
             doc_id = document.metadata['doc_id']
@@ -55,7 +54,7 @@ class DatasetIndexToolCallbackHandler:
             db.session.commit()

-    def return_retriever_resource_info(self, resource: List):
+    def return_retriever_resource_info(self, resource: list):
         """Handle return_retriever_resource_info."""
         if resource and len(resource) > 0:
             for item in resource:
@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
     def on_chat_model_start(
         self,
-        serialized: Dict[str, Any],
-        messages: List[List[BaseMessage]],
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
         **kwargs: Any
     ) -> Any:
         print_text("\n[on_chat_model_start]\n", color='blue')
@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
             print_text(str(sub_message) + "\n", color='blue')

     def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')

     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
         chain_type = serialized['id'][-1]
         print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')

-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
         print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         **kwargs: Any,
     ) -> None:
| @@ -1,4 +1,4 @@ | |||
| from typing import Any, Dict, List, Optional | |||
| from typing import Any, Optional | |||
| from langchain import LLMChain as LCLLMChain | |||
| from langchain.callbacks.manager import CallbackManagerForChainRun | |||
| @@ -16,12 +16,12 @@ class LLMChain(LCLLMChain): | |||
| model_config: ModelConfigEntity | |||
| """The language model instance to use.""" | |||
| llm: BaseLanguageModel = FakeLLM(response="") | |||
| parameters: Dict[str, Any] = {} | |||
| parameters: dict[str, Any] = {} | |||
| agent_llm_callback: Optional[AgentLLMCallback] = None | |||
| def generate( | |||
| self, | |||
| input_list: List[Dict[str, Any]], | |||
| input_list: list[dict[str, Any]], | |||
| run_manager: Optional[CallbackManagerForChainRun] = None, | |||
| ) -> LLMResult: | |||
| """Generate LLM result from inputs.""" | |||
@@ -1,6 +1,6 @@
 import tempfile
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union

 import requests
 from flask import current_app
@@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
 class FileExtractor:
     @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
+    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
         with tempfile.TemporaryDirectory() as temp_dir:
             suffix = Path(upload_file.key).suffix
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@@ -37,7 +37,7 @@ class FileExtractor:
         return cls.load_from_file(file_path, return_text, upload_file, is_automatic)

     @classmethod
-    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
         response = requests.get(url, headers={
             "User-Agent": USER_AGENT
         })
@@ -53,7 +53,7 @@ class FileExtractor:
     @classmethod
     def load_from_file(cls, file_path: str, return_text: bool = False,
                        upload_file: Optional[UploadFile] = None,
-                       is_automatic: bool = False) -> Union[List[Document], str]:
+                       is_automatic: bool = False) -> Union[list[Document], str]:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()
| @@ -1,6 +1,6 @@ | |||
| import csv | |||
| import logging | |||
| from typing import Dict, List, Optional | |||
| from typing import Optional | |||
| from langchain.document_loaders import CSVLoader as LCCSVLoader | |||
| from langchain.document_loaders.helpers import detect_file_encodings | |||
| @@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader): | |||
| self, | |||
| file_path: str, | |||
| source_column: Optional[str] = None, | |||
| csv_args: Optional[Dict] = None, | |||
| csv_args: Optional[dict] = None, | |||
| encoding: Optional[str] = None, | |||
| autodetect_encoding: bool = True, | |||
| ): | |||
| @@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader): | |||
| self.csv_args = csv_args or {} | |||
| self.autodetect_encoding = autodetect_encoding | |||
| def load(self) -> List[Document]: | |||
| def load(self) -> list[Document]: | |||
| """Load data into document objects.""" | |||
| try: | |||
| with open(self.file_path, newline="", encoding=self.encoding) as csvfile: | |||
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader):
         """Initialize with file path."""
         self._file_path = file_path

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         data = []
         keys = []
         wb = load_workbook(filename=self._file_path, read_only=True)
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from bs4 import BeautifulSoup
 from langchain.document_loaders.base import BaseLoader
@@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader):
         """Initialize with file path."""
         self._file_path = file_path

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         return [Document(page_content=self._load_as_text())]

     def _load_as_text(self) -> str:
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import List, Optional, Tuple, cast
+from typing import Optional, cast

 from langchain.document_loaders.base import BaseLoader
 from langchain.document_loaders.helpers import detect_file_encodings
@@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader):
         self._encoding = encoding
         self._autodetect_encoding = autodetect_encoding

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         tups = self.parse_tups(self._file_path)
         documents = []
         for header, value in tups:
@@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader):
         return documents

-    def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]:
+    def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]:
         """Convert a markdown file to a dictionary.

         The keys are the headers and the values are the text under each header.
         """
-        markdown_tups: List[Tuple[Optional[str], str]] = []
+        markdown_tups: list[tuple[Optional[str], str]] = []
         lines = markdown_text.split("\n")

         current_header = None
@@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader):
             content = re.sub(pattern, r"\1", content)
         return content

-    def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]:
+    def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]:
         """Parse file into tuples."""
         content = ""
         try:
-            with open(filepath, "r", encoding=self._encoding) as f:
+            with open(filepath, encoding=self._encoding) as f:
                 content = f.read()
         except UnicodeDecodeError as e:
             if self._autodetect_encoding:
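Dropping the explicit `'r'` here (and in the two `Extensible` hunks near the end of the patch) is behavior-preserving: `open()` defaults to text-read mode, so only a redundant argument goes away. A sketch with a hypothetical file name:

    # open(path, encoding=...) is identical to open(path, 'r', encoding=...)
    with open("notes.txt", encoding="utf-8") as f:
        content = f.read()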
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

 import requests
 from flask import current_app
@@ -67,7 +67,7 @@ class NotionLoader(BaseLoader):
             document_model=document_model
         )

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         self.update_last_edited_time(
             self._document_model
         )
@@ -78,7 +78,7 @@ class NotionLoader(BaseLoader):
     def _load_data_as_documents(
         self, notion_obj_id: str, notion_page_type: str
-    ) -> List[Document]:
+    ) -> list[Document]:
         docs = []
         if notion_page_type == 'database':
             # get all the pages in the database
@@ -94,8 +94,8 @@ class NotionLoader(BaseLoader):
         return docs

     def _get_notion_database_data(
-        self, database_id: str, query_dict: Dict[str, Any] = {}
-    ) -> List[Document]:
+        self, database_id: str, query_dict: dict[str, Any] = {}
+    ) -> list[Document]:
         """Get all the pages from a Notion database."""
         res = requests.post(
             DATABASE_URL_TMPL.format(database_id=database_id),
@@ -149,12 +149,12 @@ class NotionLoader(BaseLoader):
         return database_content_list

-    def _get_notion_block_data(self, page_id: str) -> List[str]:
+    def _get_notion_block_data(self, page_id: str) -> list[str]:
         result_lines_arr = []
         cur_block_id = page_id
         while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
+            query_dict: dict[str, Any] = {}

             res = requests.request(
                 "GET",
@@ -216,7 +216,7 @@ class NotionLoader(BaseLoader):
         cur_block_id = block_id
         while True:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
+            query_dict: dict[str, Any] = {}

             res = requests.request(
                 "GET",
@@ -280,7 +280,7 @@ class NotionLoader(BaseLoader):
         cur_block_id = block_id
         while not done:
             block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
-            query_dict: Dict[str, Any] = {}
+            query_dict: dict[str, Any] = {}

             res = requests.request(
                 "GET",
@@ -346,7 +346,7 @@ class NotionLoader(BaseLoader):
         else:
             retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)

-        query_dict: Dict[str, Any] = {}
+        query_dict: dict[str, Any] = {}

         res = requests.request(
             "GET",
| @@ -1,5 +1,5 @@ | |||
| import logging | |||
| from typing import List, Optional | |||
| from typing import Optional | |||
| from langchain.document_loaders import PyPDFium2Loader | |||
| from langchain.document_loaders.base import BaseLoader | |||
| @@ -28,7 +28,7 @@ class PdfLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._upload_file = upload_file | |||
| def load(self) -> List[Document]: | |||
| def load(self) -> list[Document]: | |||
| plaintext_file_key = '' | |||
| plaintext_file_exists = False | |||
| if self._upload_file: | |||
| @@ -1,6 +1,5 @@ | |||
| import base64 | |||
| import logging | |||
| from typing import List | |||
| from bs4 import BeautifulSoup | |||
| from langchain.document_loaders.base import BaseLoader | |||
| @@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader): | |||
| self._file_path = file_path | |||
| self._api_url = api_url | |||
| def load(self) -> List[Document]: | |||
| def load(self) -> list[Document]: | |||
| from unstructured.partition.email import partition_email | |||
| elements = partition_email(filename=self._file_path, api_url=self._api_url) | |||
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         from unstructured.partition.md import partition_md

         elements = partition_md(filename=self._file_path, api_url=self._api_url)
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         from unstructured.partition.msg import partition_msg

         elements = partition_msg(filename=self._file_path, api_url=self._api_url)
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         from unstructured.partition.ppt import partition_ppt

         elements = partition_ppt(filename=self._file_path, api_url=self._api_url)
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         from unstructured.partition.pptx import partition_pptx

         elements = partition_pptx(filename=self._file_path, api_url=self._api_url)
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         from unstructured.partition.text import partition_text

         elements = partition_text(filename=self._file_path, api_url=self._api_url)
@@ -1,5 +1,4 @@
 import logging
-from typing import List

 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader):
         self._file_path = file_path
         self._api_url = api_url

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         from unstructured.partition.xml import partition_xml

         elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url)
@@ -1,4 +1,5 @@
-from typing import Any, Dict, Optional, Sequence, cast
+from collections.abc import Sequence
+from typing import Any, Optional, cast

 from langchain.schema import Document
 from sqlalchemy import func
@@ -22,10 +23,10 @@ class DatasetDocumentStore:
         self._document_id = document_id

     @classmethod
-    def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore":
+    def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore":
         return cls(**config_dict)

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Serialize to dict."""
         return {
             "dataset_id": self._dataset.id,
@@ -40,7 +41,7 @@ class DatasetDocumentStore:
         return self._user_id

     @property
-    def docs(self) -> Dict[str, Document]:
+    def docs(self) -> dict[str, Document]:
         document_segments = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == self._dataset.id
         ).all()
| @@ -1,6 +1,6 @@ | |||
| import base64 | |||
| import logging | |||
| from typing import List, Optional, cast | |||
| from typing import Optional, cast | |||
| import numpy as np | |||
| from langchain.embeddings.base import Embeddings | |||
| @@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings): | |||
| self._model_instance = model_instance | |||
| self._user = user | |||
| def embed_documents(self, texts: List[str]) -> List[List[float]]: | |||
| def embed_documents(self, texts: list[str]) -> list[list[float]]: | |||
| """Embed search docs in batches of 10.""" | |||
| text_embeddings = [] | |||
| try: | |||
| @@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings): | |||
| return text_embeddings | |||
| def embed_query(self, text: str) -> List[float]: | |||
| def embed_query(self, text: str) -> list[float]: | |||
| """Embed query text.""" | |||
| # use doc embedding cache or store if not exists | |||
| hash = helper.generate_text_hash(text) | |||
@@ -1,8 +1,9 @@
 import datetime
 import json
 import logging
+from collections.abc import Iterator
 from json import JSONDecodeError
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Optional

 from pydantic import BaseModel
@@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel):
             if self.provider.provider_credential_schema else []
         )

-    def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]:
+    def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]:
         """
         Validate custom credentials.
         :param credentials: provider credentials
@@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel):
         return None

     def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
-            -> Tuple[ProviderModel, dict]:
+            -> tuple[ProviderModel, dict]:
         """
         Validate custom model credentials.
@@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel):
     Model class for provider configuration dict.
     """
     tenant_id: str
-    configurations: Dict[str, ProviderConfiguration] = {}
+    configurations: dict[str, ProviderConfiguration] = {}

     def __init__(self, tenant_id: str):
         super().__init__(tenant_id=tenant_id)
@@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel):
         return all_models

-    def to_list(self) -> List[ProviderConfiguration]:
+    def to_list(self) -> list[ProviderConfiguration]:
         """
         Convert to list.
@@ -61,7 +61,7 @@ class Extensible:
             builtin_file_path = os.path.join(subdir_path, '__builtin__')
             if os.path.exists(builtin_file_path):
-                with open(builtin_file_path, 'r', encoding='utf-8') as f:
+                with open(builtin_file_path, encoding='utf-8') as f:
                     position = int(f.read().strip())

             if (extension_name + '.py') not in file_names:
@@ -93,7 +93,7 @@ class Extensible:
                 json_path = os.path.join(subdir_path, 'schema.json')
                 json_data = {}
                 if os.path.exists(json_path):
-                    with open(json_path, 'r', encoding='utf-8') as f:
+                    with open(json_path, encoding='utf-8') as f:
                         json_data = json.load(f)

                 extensions[extension_name] = ModuleExtension(
@@ -2,7 +2,7 @@ import json
 import logging
 from datetime import datetime
 from mimetypes import guess_extension
-from typing import List, Optional, Tuple, Union, cast
+from typing import Optional, Union, cast

 from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager
@@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  message: Message,
                  user_id: str,
                  memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[List[PromptMessage]] = None,
+                 prompt_messages: Optional[list[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
                  model_instance: ModelInstance = None
@@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return app_orchestration_config

-    def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
+    def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
         """
         Handle tool response
         """
@@ -140,7 +140,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return result

-    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
+    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
         convert tool to prompt message tool
         """
@@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return prompt_tool

-    def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
+    def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
         """
         Extract tool response binary
         """
@@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return result

-    def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
+    def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
         """
         Create message file
@@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return result

     def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: List[str]
+                             tool_name: str, tool_input: str, messages_ids: list[str]
                              ) -> MessageAgentThought:
         """
         Create agent thought
@@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                            thought: str,
                            observation: str,
                            answer: str,
-                           messages_ids: List[str],
+                           messages_ids: list[str],
                            llm_usage: LLMUsage = None) -> MessageAgentThought:
         """
         Save agent thought
| @@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| db.session.commit() | |||
| def get_history_prompt_messages(self) -> List[PromptMessage]: | |||
| def get_history_prompt_messages(self) -> list[PromptMessage]: | |||
| """ | |||
| Get history prompt messages | |||
| """ | |||
| @@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner): | |||
| return self.history_prompt_messages | |||
| def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]: | |||
| def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: | |||
| """ | |||
| Transform tool message into agent thought | |||
| """ | |||
| @@ -1,6 +1,7 @@ | |||
| import json | |||
| import re | |||
| from typing import Dict, Generator, List, Literal, Union | |||
| from collections.abc import Generator | |||
| from typing import Literal, Union | |||
| from core.application_queue_manager import PublishFrom | |||
| from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit | |||
| @@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): | |||
| def run(self, conversation: Conversation, | |||
| message: Message, | |||
| query: str, | |||
| inputs: Dict[str, str], | |||
| inputs: dict[str, str], | |||
| ) -> Union[Generator, LLMResult]: | |||
| """ | |||
| Run Cot agent application | |||
| @@ -37,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): | |||
| app_orchestration_config = self.app_orchestration_config | |||
| self._repack_app_orchestration_config(app_orchestration_config) | |||
| agent_scratchpad: List[AgentScratchpadUnit] = [] | |||
| agent_scratchpad: list[AgentScratchpadUnit] = [] | |||
| # check model mode | |||
| if self.app_orchestration_config.model_config.mode == "completion": | |||
| @@ -56,7 +57,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): | |||
| prompt_messages = self.history_prompt_messages | |||
| # convert tools into ModelRuntime Tool format | |||
| prompt_messages_tools: List[PromptMessageTool] = [] | |||
| prompt_messages_tools: list[PromptMessageTool] = [] | |||
| tool_instances = {} | |||
| for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: | |||
| try: | |||
| @@ -83,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): | |||
| } | |||
| final_answer = '' | |||
| def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): | |||
| def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): | |||
| if not final_llm_usage_dict['usage']: | |||
| final_llm_usage_dict['usage'] = usage | |||
| else: | |||
| @@ -493,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): | |||
| if not next_iteration.find("{{observation}}") >= 0: | |||
| raise ValueError("{{observation}} is required in next_iteration") | |||
| def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str: | |||
| def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: | |||
| """ | |||
| convert agent scratchpad list to str | |||
| """ | |||
| @@ -506,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): | |||
| return result | |||
| def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], | |||
| prompt_messages: List[PromptMessage], | |||
| tools: List[PromptMessageTool], | |||
| agent_scratchpad: List[AgentScratchpadUnit], | |||
| prompt_messages: list[PromptMessage], | |||
| tools: list[PromptMessageTool], | |||
| agent_scratchpad: list[AgentScratchpadUnit], | |||
| agent_prompt_message: AgentPromptEntity, | |||
| instruction: str, | |||
| input: str, | |||
| ) -> List[PromptMessage]: | |||
| ) -> list[PromptMessage]: | |||
| """ | |||
| organize chain of thought prompt messages, a standard prompt message is like: | |||
| Respond to the human as helpfully and accurately as possible. | |||
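The import change at the top of this file reflects the companion deprecation: since Python 3.9 the ABC aliases in `typing` (such as `Generator`) are deprecated in favor of `collections.abc`, which is equally subscriptable. A small runnable sketch:

    from collections.abc import Generator

    def stream_chunks(text: str, size: int) -> Generator[str, None, None]:
        # yield the text in fixed-size chunks
        for i in range(0, len(text), size):
            yield text[i:i + size]

    assert list(stream_chunks('abcdef', 2)) == ['ab', 'cd', 'ef']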
| @@ -1,6 +1,7 @@ | |||
| import json | |||
| import logging | |||
| from typing import Any, Dict, Generator, List, Tuple, Union | |||
| from collections.abc import Generator | |||
| from typing import Any, Union | |||
| from core.application_queue_manager import PublishFrom | |||
| from core.features.assistant_base_runner import BaseAssistantApplicationRunner | |||
| @@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): | |||
| ) | |||
| # convert tools into ModelRuntime Tool format | |||
| prompt_messages_tools: List[PromptMessageTool] = [] | |||
| prompt_messages_tools: list[PromptMessageTool] = [] | |||
| tool_instances = {} | |||
| for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: | |||
| try: | |||
| @@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): | |||
| # continue to run until there are no more tool calls | |||
| function_call_state = True | |||
| agent_thoughts: List[MessageAgentThought] = [] | |||
| agent_thoughts: list[MessageAgentThought] = [] | |||
| llm_usage = { | |||
| 'usage': None | |||
| } | |||
| final_answer = '' | |||
| def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): | |||
| def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): | |||
| if not final_llm_usage_dict['usage']: | |||
| final_llm_usage_dict['usage'] = usage | |||
| else: | |||
| @@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): | |||
| callbacks=[], | |||
| ) | |||
| tool_calls: List[Tuple[str, str, Dict[str, Any]]] = [] | |||
| tool_calls: list[tuple[str, str, dict[str, Any]]] = [] | |||
| # save full response | |||
| response = '' | |||
| @@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): | |||
| return True | |||
| return False | |||
| def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: | |||
| def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: | |||
| """ | |||
| Extract tool calls from llm result chunk | |||
| @@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): | |||
| return tool_calls | |||
| def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: | |||
| def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: | |||
| """ | |||
| Extract blocking tool calls from llm result | |||
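Nested annotations rewrite mechanically, while `Union` (and `Optional`) still come from `typing` at this Python level. A sketch using a hypothetical `ToolCall` alias that mirrors the tuple shape above:

    from typing import Any, Union

    ToolCall = tuple[str, str, dict[str, Any]]   # hypothetical alias for readability

    def first_call(calls: list[ToolCall]) -> Union[None, ToolCall]:
        # return the first queued tool call, if any
        return calls[0] if calls else None

    assert first_call([]) is None
    assert first_call([('id-1', 'search', {'q': 'x'})]) == ('id-1', 'search', {'q': 'x'})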
| @@ -1,4 +1,4 @@ | |||
| from typing import List, Optional, cast | |||
| from typing import Optional, cast | |||
| from langchain.tools import BaseTool | |||
| @@ -96,7 +96,7 @@ class DatasetRetrievalFeature: | |||
| return_resource: bool, | |||
| invoke_from: InvokeFrom, | |||
| hit_callback: DatasetIndexToolCallbackHandler) \ | |||
| -> Optional[List[BaseTool]]: | |||
| -> Optional[list[BaseTool]]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param tenant_id: tenant id | |||
| @@ -2,7 +2,7 @@ import concurrent | |||
| import json | |||
| import logging | |||
| from concurrent.futures import ThreadPoolExecutor | |||
| from typing import Optional, Tuple | |||
| from typing import Optional | |||
| from flask import Flask, current_app | |||
| @@ -62,7 +62,7 @@ class ExternalDataFetchFeature: | |||
| app_id: str, | |||
| external_data_tool: ExternalDataVariableEntity, | |||
| inputs: dict, | |||
| query: str) -> Tuple[Optional[str], Optional[str]]: | |||
| query: str) -> tuple[Optional[str], Optional[str]]: | |||
| """ | |||
| Query external data tool. | |||
| :param flask_app: flask app | |||
| @@ -1,5 +1,4 @@ | |||
| import logging | |||
| from typing import Tuple | |||
| from core.entities.application_entities import AppOrchestrationConfigEntity | |||
| from core.moderation.base import ModerationAction, ModerationException | |||
| @@ -13,7 +12,7 @@ class ModerationFeature: | |||
| tenant_id: str, | |||
| app_orchestration_config_entity: AppOrchestrationConfigEntity, | |||
| inputs: dict, | |||
| query: str) -> Tuple[bool, dict, str]: | |||
| query: str) -> tuple[bool, dict, str]: | |||
| """ | |||
| Process sensitive_word_avoidance. | |||
| :param app_id: app id | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import Dict, List, Optional, Union | |||
| from typing import Optional, Union | |||
| import requests | |||
| @@ -15,8 +15,8 @@ class MessageFileParser: | |||
| self.tenant_id = tenant_id | |||
| self.app_id = app_id | |||
| def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig, | |||
| user: Union[Account, EndUser]) -> List[FileObj]: | |||
| def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, | |||
| user: Union[Account, EndUser]) -> list[FileObj]: | |||
| """ | |||
| validate and transform files arg | |||
| @@ -96,7 +96,7 @@ class MessageFileParser: | |||
| # return all file objs | |||
| return new_files | |||
| def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]: | |||
| def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]: | |||
| """ | |||
| transform message files | |||
| @@ -110,8 +110,8 @@ class MessageFileParser: | |||
| # return all file objs | |||
| return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] | |||
| def _to_file_objs(self, files: List[Union[Dict, MessageFile]], | |||
| file_upload_config: dict) -> Dict[FileType, List[FileObj]]: | |||
| def _to_file_objs(self, files: list[Union[dict, MessageFile]], | |||
| file_upload_config: dict) -> dict[FileType, list[FileObj]]: | |||
| """ | |||
| transform files to file objs | |||
| @@ -119,7 +119,7 @@ class MessageFileParser: | |||
| :param file_upload_config: | |||
| :return: | |||
| """ | |||
| type_file_objs: Dict[FileType, List[FileObj]] = { | |||
| type_file_objs: dict[FileType, list[FileObj]] = { | |||
| # Currently only support image | |||
| FileType.IMAGE: [] | |||
| } | |||
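Generics nest the same way, as in the type-to-files mapping above. A minimal stand-in with a hypothetical enum and file name:

    from enum import Enum

    class FileType(Enum):        # minimal stand-in for the real enum
        IMAGE = 'image'

    type_file_objs: dict[FileType, list[str]] = {
        FileType.IMAGE: []       # currently only images, mirroring the code above
    }
    type_file_objs[FileType.IMAGE].append('photo.png')   # hypothetical file
    assert type_file_objs[FileType.IMAGE] == ['photo.png']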
| @@ -1,7 +1,7 @@ | |||
| from __future__ import annotations | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, List | |||
| from typing import Any | |||
| from langchain.schema import BaseRetriever, Document | |||
| @@ -53,7 +53,7 @@ class BaseIndex(ABC): | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| def delete(self) -> None: | |||
| @@ -1,5 +1,4 @@ | |||
| import re | |||
| from typing import Set | |||
| import jieba | |||
| from jieba.analyse import default_tfidf | |||
| @@ -12,7 +11,7 @@ class JiebaKeywordTableHandler: | |||
| def __init__(self): | |||
| default_tfidf.stop_words = STOPWORDS | |||
| def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]: | |||
| def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]: | |||
| """Extract keywords with JIEBA tfidf.""" | |||
| keywords = jieba.analyse.extract_tags( | |||
| sentence=text, | |||
| @@ -21,7 +20,7 @@ class JiebaKeywordTableHandler: | |||
| return set(self._expand_tokens_with_subtokens(keywords)) | |||
| def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]: | |||
| def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: | |||
| """Get subtokens from a list of tokens., filtering for stopwords.""" | |||
| results = set() | |||
| for token in tokens: | |||
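`typing.Set` gets the same treatment. A runnable sketch of sub-token expansion in the spirit of the method above; the splitting rule here is assumed, not taken from the source:

    def expand_tokens_with_subtokens(tokens: set[str]) -> set[str]:
        results: set[str] = set()
        for token in tokens:
            results.add(token)
            results.update(token.split('-'))   # hypothetical sub-token rule
        return results

    assert expand_tokens_with_subtokens({'data-set'}) == {'data-set', 'data', 'set'}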
| @@ -1,6 +1,6 @@ | |||
| import json | |||
| from collections import defaultdict | |||
| from typing import Any, Dict, List, Optional | |||
| from typing import Any, Optional | |||
| from langchain.schema import BaseRetriever, Document | |||
| from pydantic import BaseModel, Extra, Field | |||
| @@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex): | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| ) -> list[Document]: | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} | |||
| @@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex): | |||
| keywords = keyword_table_handler.extract_keywords(query) | |||
| # go through text chunks in order of most matching keywords | |||
| chunk_indices_count: Dict[str, int] = defaultdict(int) | |||
| chunk_indices_count: dict[str, int] = defaultdict(int) | |||
| keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] | |||
| for keyword in keywords: | |||
| for node_id in keyword_table[keyword]: | |||
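For context, the `defaultdict(int)` above drives a count-and-rank lookup: each chunk is scored by how many query keywords point at it, and the top k are returned. A hedged sketch of that shape with made-up data:

    from collections import defaultdict

    keyword_table = {'cat': ['n1', 'n2'], 'dog': ['n1']}   # hypothetical index
    query_keywords = ['cat', 'dog']

    chunk_indices_count: dict[str, int] = defaultdict(int)
    for keyword in query_keywords:
        for node_id in keyword_table[keyword]:
            chunk_indices_count[node_id] += 1

    # rank chunks by number of matching keywords, descending
    sorted_chunk_indices = sorted(chunk_indices_count,
                                  key=chunk_indices_count.get, reverse=True)
    assert sorted_chunk_indices[:2] == ['n1', 'n2']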
| @@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex): | |||
| return sorted_chunk_indices[: k] | |||
| def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]): | |||
| def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): | |||
| document_segment = db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.index_node_id == node_id | |||
| @@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex): | |||
| document_segment.keywords = keywords | |||
| db.session.commit() | |||
| def create_segment_keywords(self, node_id: str, keywords: List[str]): | |||
| def create_segment_keywords(self, node_id: str, keywords: list[str]): | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| self._update_segment_keywords(self.dataset.id, node_id, keywords) | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) | |||
| @@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex): | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| def update_segment_keywords_index(self, node_id: str, keywords: List[str]): | |||
| def update_segment_keywords_index(self, node_id: str, keywords: list[str]): | |||
| keyword_table = self._get_dataset_keyword_table() | |||
| keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) | |||
| self._save_dataset_keyword_table(keyword_table) | |||
| @@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel): | |||
| extra = Extra.forbid | |||
| arbitrary_types_allowed = True | |||
| def get_relevant_documents(self, query: str) -> List[Document]: | |||
| def get_relevant_documents(self, query: str) -> list[Document]: | |||
| """Get documents relevant for a query. | |||
| Args: | |||
| @@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel): | |||
| """ | |||
| return self.index.search(query, **self.search_kwargs) | |||
| async def aget_relevant_documents(self, query: str) -> List[Document]: | |||
| async def aget_relevant_documents(self, query: str) -> list[Document]: | |||
| raise NotImplementedError("KeywordTableRetriever does not support async") | |||
| @@ -1,7 +1,7 @@ | |||
| import json | |||
| import logging | |||
| from abc import abstractmethod | |||
| from typing import Any, List, cast | |||
| from typing import Any, cast | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import BaseRetriever, Document | |||
| @@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex): | |||
| def search_by_full_text_index( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| def search( | |||
| self, query: str, | |||
| **kwargs: Any | |||
| ) -> List[Document]: | |||
| ) -> list[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import Any, List, cast | |||
| from typing import Any, cast | |||
| from langchain.embeddings.base import Embeddings | |||
| from langchain.schema import Document | |||
| @@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex): | |||
| ], | |||
| )) | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||
| # milvus/zilliz doesn't support bm25 search | |||
| return [] | |||
| @@ -1,5 +1,5 @@ | |||
| import os | |||
| from typing import Any, List, Optional, cast | |||
| from typing import Any, Optional, cast | |||
| import qdrant_client | |||
| from langchain.embeddings.base import Embeddings | |||
| @@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex): | |||
| return False | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| @@ -1,4 +1,4 @@ | |||
| from typing import Any, List, Optional, cast | |||
| from typing import Any, Optional, cast | |||
| import requests | |||
| import weaviate | |||
| @@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex): | |||
| return False | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: | |||
| def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: | |||
| vector_store = self._get_vector_store() | |||
| vector_store = cast(self._get_vector_store_class(), vector_store) | |||
| return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) | |||
| @@ -5,7 +5,7 @@ import re | |||
| import threading | |||
| import time | |||
| import uuid | |||
| from typing import List, Optional, cast | |||
| from typing import Optional, cast | |||
| from flask import Flask, current_app | |||
| from flask_login import current_user | |||
| @@ -40,7 +40,7 @@ class IndexingRunner: | |||
| self.storage = storage | |||
| self.model_manager = ModelManager() | |||
| def run(self, dataset_documents: List[DatasetDocument]): | |||
| def run(self, dataset_documents: list[DatasetDocument]): | |||
| """Run the indexing process.""" | |||
| for dataset_document in dataset_documents: | |||
| try: | |||
| @@ -238,7 +238,7 @@ class IndexingRunner: | |||
| dataset_document.stopped_at = datetime.datetime.utcnow() | |||
| db.session.commit() | |||
| def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict, | |||
| def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, | |||
| doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, | |||
| indexing_technique: str = 'economy') -> dict: | |||
| """ | |||
| @@ -494,7 +494,7 @@ class IndexingRunner: | |||
| "preview": preview_texts | |||
| } | |||
| def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]: | |||
| def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: | |||
| # load file | |||
| if dataset_document.data_source_type not in ["upload_file", "notion_import"]: | |||
| return [] | |||
| @@ -526,7 +526,7 @@ class IndexingRunner: | |||
| ) | |||
| # replace doc id to document model id | |||
| text_docs = cast(List[Document], text_docs) | |||
| text_docs = cast(list[Document], text_docs) | |||
| for text_doc in text_docs: | |||
| # remove invalid symbol | |||
| text_doc.page_content = self.filter_string(text_doc.page_content) | |||
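`typing.cast` accepts the builtin generic forms as well, so `cast(List[Document], ...)` rewrites directly to `cast(list[Document], ...)`. A tiny sketch:

    from typing import cast

    raw: object = ['a', 'b']
    items = cast(list[str], raw)   # no-op at runtime, informs the type checker
    assert items == ['a', 'b']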
| @@ -540,7 +540,7 @@ class IndexingRunner: | |||
| text = re.sub(r'\|>', '>', text) | |||
| text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) | |||
| # Unicode U+FFFE | |||
| text = re.sub(u'\uFFFE', '', text) | |||
| text = re.sub('\uFFFE', '', text) | |||
| return text | |||
| def _get_splitter(self, processing_rule: DatasetProcessRule, | |||
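The `u''` prefix removed above is a Python 2 relic; every `str` literal is Unicode in Python 3, so dropping it is purely cosmetic. For example:

    import re

    text = 'abc\uFFFEdef'
    clean = re.sub('\uFFFE', '', text)   # same result as the old re.sub(u'\uFFFE', ...)
    assert clean == 'abcdef'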
| @@ -577,9 +577,9 @@ class IndexingRunner: | |||
| return character_splitter | |||
| def _step_split(self, text_docs: List[Document], splitter: TextSplitter, | |||
| def _step_split(self, text_docs: list[Document], splitter: TextSplitter, | |||
| dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ | |||
| -> List[Document]: | |||
| -> list[Document]: | |||
| """ | |||
| Split the text documents into documents and save them to the document segment. | |||
| """ | |||
| @@ -624,9 +624,9 @@ class IndexingRunner: | |||
| return documents | |||
| def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, | |||
| def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, | |||
| processing_rule: DatasetProcessRule, tenant_id: str, | |||
| document_form: str, document_language: str) -> List[Document]: | |||
| document_form: str, document_language: str) -> list[Document]: | |||
| """ | |||
| Split the text documents into nodes. | |||
| """ | |||
| @@ -699,8 +699,8 @@ class IndexingRunner: | |||
| all_qa_documents.extend(format_documents) | |||
| def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, | |||
| processing_rule: DatasetProcessRule) -> List[Document]: | |||
| def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, | |||
| processing_rule: DatasetProcessRule) -> list[Document]: | |||
| """ | |||
| Split the text documents into nodes. | |||
| """ | |||
| @@ -770,7 +770,7 @@ class IndexingRunner: | |||
| for q, a in matches if q and a | |||
| ] | |||
| def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: | |||
| def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: | |||
| """ | |||
| Build the index for the document. | |||
| """ | |||
| @@ -877,7 +877,7 @@ class IndexingRunner: | |||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | |||
| db.session.commit() | |||
| def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset): | |||
| def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): | |||
| """ | |||
| Batch add segments index processing | |||
| """ | |||
| @@ -1,4 +1,5 @@ | |||
| from typing import IO, Generator, List, Optional, Union, cast | |||
| from collections.abc import Generator | |||
| from typing import IO, Optional, Union, cast | |||
| from core.entities.provider_configuration import ProviderModelBundle | |||
| from core.errors.error import ProviderTokenNotInitError | |||
| @@ -47,7 +48,7 @@ class ModelInstance: | |||
| return credentials | |||
| def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| @@ -1,5 +1,5 @@ | |||
| from abc import ABC | |||
| from typing import List, Optional | |||
| from typing import Optional | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool | |||
| @@ -23,7 +23,7 @@ class Callback(ABC): | |||
| def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| Before invoke callback | |||
| @@ -42,7 +42,7 @@ class Callback(ABC): | |||
| def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None): | |||
| """ | |||
| On new chunk callback | |||
| @@ -62,7 +62,7 @@ class Callback(ABC): | |||
| def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| After invoke callback | |||
| @@ -82,7 +82,7 @@ class Callback(ABC): | |||
| def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| Invoke error callback | |||
| @@ -1,7 +1,7 @@ | |||
| import json | |||
| import logging | |||
| import sys | |||
| from typing import List, Optional | |||
| from typing import Optional | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
| from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk | |||
| @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) | |||
| class LoggingCallback(Callback): | |||
| def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| Before invoke callback | |||
| @@ -60,7 +60,7 @@ class LoggingCallback(Callback): | |||
| def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None): | |||
| """ | |||
| On new chunk callback | |||
| @@ -81,7 +81,7 @@ class LoggingCallback(Callback): | |||
| def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| After invoke callback | |||
| @@ -113,7 +113,7 @@ class LoggingCallback(Callback): | |||
| def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) -> None: | |||
| """ | |||
| Invoke error callback | |||
| @@ -1,8 +1,7 @@ | |||
| from typing import Dict | |||
| from core.model_runtime.entities.model_entities import DefaultParameterName | |||
| PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = { | |||
| PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { | |||
| DefaultParameterName.TEMPERATURE: { | |||
| 'label': { | |||
| 'en_US': 'Temperature', | |||
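Annotated module-level constants rewrite the same way as signatures. A minimal stand-in, with a hypothetical enum in place of `DefaultParameterName`:

    from enum import Enum

    class ParamName(Enum):                      # hypothetical stand-in enum
        TEMPERATURE = 'temperature'

    RULE_TEMPLATE: dict[ParamName, dict] = {    # builtin dict in the annotation
        ParamName.TEMPERATURE: {'label': {'en_US': 'Temperature'}},
    }
    assert 'label' in RULE_TEMPLATE[ParamName.TEMPERATURE]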
| @@ -153,7 +153,7 @@ class AIModel(ABC): | |||
| # read _position.yaml file | |||
| position_map = {} | |||
| if os.path.exists(position_file_path): | |||
| with open(position_file_path, 'r', encoding='utf-8') as f: | |||
| with open(position_file_path, encoding='utf-8') as f: | |||
| positions = yaml.safe_load(f) | |||
| # convert list to dict with key as model provider name, value as index | |||
| position_map = {position: index for index, position in enumerate(positions)} | |||
| @@ -161,7 +161,7 @@ class AIModel(ABC): | |||
| # traverse all model_schema_yaml_paths | |||
| for model_schema_yaml_path in model_schema_yaml_paths: | |||
| # read yaml data from yaml file | |||
| with open(model_schema_yaml_path, 'r', encoding='utf-8') as f: | |||
| with open(model_schema_yaml_path, encoding='utf-8') as f: | |||
| yaml_data = yaml.safe_load(f) | |||
| new_parameter_rules = [] | |||
| @@ -3,7 +3,8 @@ import os | |||
| import re | |||
| import time | |||
| from abc import abstractmethod | |||
| from typing import Generator, List, Optional, Union | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| from core.model_runtime.callbacks.base_callback import Callback | |||
| from core.model_runtime.callbacks.logging_callback import LoggingCallback | |||
| @@ -29,7 +30,7 @@ class LargeLanguageModel(AIModel): | |||
| def invoke(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| @@ -122,7 +123,7 @@ class LargeLanguageModel(AIModel): | |||
| def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: | |||
| """ | |||
| Invoke result generator | |||
| @@ -186,7 +187,7 @@ class LargeLanguageModel(AIModel): | |||
| @abstractmethod | |||
| def _invoke(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| @@ -218,7 +219,7 @@ class LargeLanguageModel(AIModel): | |||
| """ | |||
| raise NotImplementedError | |||
| def enforce_stop_tokens(self, text: str, stop: List[str]) -> str: | |||
| def enforce_stop_tokens(self, text: str, stop: list[str]) -> str: | |||
| """Cut off the text as soon as any stop words occur.""" | |||
| return re.split("|".join(stop), text, maxsplit=1)[0] | |||
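A usage example makes the cut-off behavior of `enforce_stop_tokens` concrete. Note that the stop strings are joined into a raw regex alternation, so stop words containing regex metacharacters would need `re.escape`:

    import re

    def enforce_stop_tokens(text: str, stop: list[str]) -> str:
        """Cut off the text as soon as any stop word occurs."""
        return re.split('|'.join(stop), text, maxsplit=1)[0]

    out = enforce_stop_tokens('Thought: search\nObservation: done', ['\nObservation:'])
    assert out == 'Thought: search'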
| @@ -329,7 +330,7 @@ class LargeLanguageModel(AIModel): | |||
| def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| """ | |||
| Trigger before invoke callbacks | |||
| @@ -367,7 +368,7 @@ class LargeLanguageModel(AIModel): | |||
| def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| """ | |||
| Trigger new chunk callbacks | |||
| @@ -406,7 +407,7 @@ class LargeLanguageModel(AIModel): | |||
| def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| """ | |||
| Trigger after invoke callbacks | |||
| @@ -446,7 +447,7 @@ class LargeLanguageModel(AIModel): | |||
| def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None, callbacks: list[Callback] = None) -> None: | |||
| """ | |||
| Trigger invoke error callbacks | |||
| @@ -527,7 +528,7 @@ class LargeLanguageModel(AIModel): | |||
| raise ValueError( | |||
| f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") | |||
| elif parameter_rule.type == ParameterType.FLOAT: | |||
| if not isinstance(parameter_value, (float, int)): | |||
| if not isinstance(parameter_value, float | int): | |||
| raise ValueError(f"Model Parameter {parameter_name} should be float.") | |||
| # validate parameter value precision | |||
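The `isinstance` rewrite relies on PEP 604: from Python 3.10 onward, `X | Y` union objects are accepted directly by `isinstance()` and `issubclass()`. For example:

    value = 3.14
    # Python 3.10+: equivalent to isinstance(value, (float, int))
    assert isinstance(value, float | int)
    assert not isinstance('3.14', float | int)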
| @@ -1,7 +1,6 @@ | |||
| import importlib | |||
| import os | |||
| from abc import ABC, abstractmethod | |||
| from typing import Dict | |||
| import yaml | |||
| @@ -12,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel | |||
| class ModelProvider(ABC): | |||
| provider_schema: ProviderEntity = None | |||
| model_instance_map: Dict[str, AIModel] = {} | |||
| model_instance_map: dict[str, AIModel] = {} | |||
| @abstractmethod | |||
| def validate_provider_credentials(self, credentials: dict) -> None: | |||
| @@ -47,7 +46,7 @@ class ModelProvider(ABC): | |||
| yaml_path = os.path.join(current_path, f'{provider_name}.yaml') | |||
| yaml_data = {} | |||
| if os.path.exists(yaml_path): | |||
| with open(yaml_path, 'r', encoding='utf-8') as f: | |||
| with open(yaml_path, encoding='utf-8') as f: | |||
| yaml_data = yaml.safe_load(f) | |||
| try: | |||
| @@ -1,4 +1,5 @@ | |||
| from typing import Generator, List, Optional, Union | |||
| from collections.abc import Generator | |||
| from typing import Optional, Union | |||
| import anthropic | |||
| from anthropic import Anthropic, Stream | |||
| @@ -29,7 +30,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| def _invoke(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, | |||
| tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, | |||
| stream: bool = True, user: Optional[str] = None) \ | |||
| -> Union[LLMResult, Generator]: | |||
| """ | |||
| @@ -90,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| def _generate(self, model: str, credentials: dict, | |||
| prompt_messages: list[PromptMessage], model_parameters: dict, | |||
| stop: Optional[List[str]] = None, stream: bool = True, | |||
| stop: Optional[list[str]] = None, stream: bool = True, | |||
| user: Optional[str] = None) -> Union[LLMResult, Generator]: | |||
| """ | |||
| Invoke large language model | |||
| @@ -255,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): | |||
| return message_text | |||
| def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str: | |||
| def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str: | |||
| """ | |||
| Format a list of messages into a full prompt for the Anthropic model | |||
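For orientation, here is a hedged sketch of what such a conversion typically produces for Anthropic's legacy completion API; the role markers and role names are assumptions, not taken from the source:

    def convert_messages_to_prompt(messages: list[tuple[str, str]]) -> str:
        parts = []
        for role, text in messages:
            marker = '\n\nHuman:' if role == 'user' else '\n\nAssistant:'
            parts.append(f'{marker} {text}')
        parts.append('\n\nAssistant:')   # leave the turn open for the model
        return ''.join(parts)

    print(repr(convert_messages_to_prompt([('user', 'Hello')])))
    # '\n\nHuman: Hello\n\nAssistant:'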