Selaa lähdekoodia

feat: integrate flask-orjson for improved JSON serialization performance (#23935)

tags/1.8.0
-LAN- 2 kuukautta sitten
vanhempi
commit
e340fccafb
No account linked to committer's email address

+ 2
- 0
api/app_factory.py Näytä tiedosto

ext_login, ext_login,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_orjson,
ext_otel, ext_otel,
ext_proxy_fix, ext_proxy_fix,
ext_redis, ext_redis,
ext_logging, ext_logging,
ext_warnings, ext_warnings,
ext_import_modules, ext_import_modules,
ext_orjson,
ext_set_secretkey, ext_set_secretkey,
ext_compress, ext_compress,
ext_code_based_extension, ext_code_based_extension,

+ 2
- 2
api/core/helper/code_executor/template_transformer.py Näytä tiedosto

from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any


from core.variables.utils import SegmentJSONEncoder
from core.variables.utils import dumps_with_segments




class TemplateTransformer(ABC): class TemplateTransformer(ABC):


@classmethod @classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
inputs_json_str = dumps_with_segments(inputs, ensure_ascii=False).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded return input_base64_encoded



+ 15
- 11
api/core/rag/datasource/keyword/jieba/jieba.py Näytä tiedosto

import json
from collections import defaultdict from collections import defaultdict
from typing import Any, Optional from typing import Any, Optional


import orjson
from pydantic import BaseModel from pydantic import BaseModel


from configs import dify_config from configs import dify_config
dataset_keyword_table = self.dataset.dataset_keyword_table dataset_keyword_table = self.dataset.dataset_keyword_table
keyword_data_source_type = dataset_keyword_table.data_source_type keyword_data_source_type = dataset_keyword_table.data_source_type
if keyword_data_source_type == "database": if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
dataset_keyword_table.keyword_table = dumps_with_sets(keyword_table_dict)
db.session.commit() db.session.commit()
else: else:
file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt"
if storage.exists(file_key): if storage.exists(file_key):
storage.delete(file_key) storage.delete(file_key)
storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode("utf-8"))
storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8"))


def _get_dataset_keyword_table(self) -> Optional[dict]: def _get_dataset_keyword_table(self) -> Optional[dict]:
dataset_keyword_table = self.dataset.dataset_keyword_table dataset_keyword_table = self.dataset.dataset_keyword_table
data_source_type=keyword_data_source_type, data_source_type=keyword_data_source_type,
) )
if keyword_data_source_type == "database": if keyword_data_source_type == "database":
dataset_keyword_table.keyword_table = json.dumps(
dataset_keyword_table.keyword_table = dumps_with_sets(
{ {
"__type__": "keyword_table", "__type__": "keyword_table",
"__data__": {"index_id": self.dataset.id, "summary": None, "table": {}}, "__data__": {"index_id": self.dataset.id, "summary": None, "table": {}},
},
cls=SetEncoder,
}
) )
db.session.add(dataset_keyword_table) db.session.add(dataset_keyword_table)
db.session.commit() db.session.commit()
self._save_dataset_keyword_table(keyword_table) self._save_dataset_keyword_table(keyword_table)




class SetEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, set):
return list(obj)
return super().default(obj)
def set_orjson_default(obj: Any) -> Any:
"""Default function for orjson serialization of set types"""
if isinstance(obj, set):
return list(obj)
raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def dumps_with_sets(obj: Any) -> str:
"""JSON dumps with set support using orjson"""
return orjson.dumps(obj, default=set_orjson_default).decode("utf-8")

+ 20
- 13
api/core/variables/utils.py Näytä tiedosto

import json
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import Any

import orjson


from .segment_group import SegmentGroup from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment from .segments import ArrayFileSegment, FileSegment, Segment
return selectors return selectors




class SegmentJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
else:
super().default(o)
def segment_orjson_default(o: Any) -> Any:
"""Default function for orjson serialization of Segment types"""
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str:
"""JSON dumps with segment support using orjson"""
option = orjson.OPT_NON_STR_KEYS
return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8")

+ 8
- 0
api/extensions/ext_orjson.py Näytä tiedosto

from flask_orjson import OrjsonProvider

from dify_app import DifyApp


def init_app(app: DifyApp) -> None:
"""Initialize Flask-Orjson extension for faster JSON serialization"""
app.json = OrjsonProvider(app)

+ 1
- 1
api/models/workflow.py Näytä tiedosto

value: The Segment object to store as the variable's value. value: The Segment object to store as the variable's value.
""" """
self.__value = value self.__value = value
self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder)
self.value = variable_utils.dumps_with_segments(value)
self.value_type = value.value_type self.value_type = value.value_type


def get_node_id(self) -> str | None: def get_node_id(self) -> str | None:

+ 1
- 0
api/pyproject.toml Näytä tiedosto

"flask-cors~=6.0.0", "flask-cors~=6.0.0",
"flask-login~=0.6.3", "flask-login~=0.6.3",
"flask-migrate~=4.0.7", "flask-migrate~=4.0.7",
"flask-orjson~=2.0.0",
"flask-restful~=0.3.10", "flask-restful~=0.3.10",
"flask-sqlalchemy~=3.1.1", "flask-sqlalchemy~=3.1.1",
"gevent~=24.11.1", "gevent~=24.11.1",

+ 15
- 0
api/uv.lock Näytä tiedosto

{ name = "flask-cors" }, { name = "flask-cors" },
{ name = "flask-login" }, { name = "flask-login" },
{ name = "flask-migrate" }, { name = "flask-migrate" },
{ name = "flask-orjson" },
{ name = "flask-restful" }, { name = "flask-restful" },
{ name = "flask-sqlalchemy" }, { name = "flask-sqlalchemy" },
{ name = "gevent" }, { name = "gevent" },
{ name = "flask-cors", specifier = "~=6.0.0" }, { name = "flask-cors", specifier = "~=6.0.0" },
{ name = "flask-login", specifier = "~=0.6.3" }, { name = "flask-login", specifier = "~=0.6.3" },
{ name = "flask-migrate", specifier = "~=4.0.7" }, { name = "flask-migrate", specifier = "~=4.0.7" },
{ name = "flask-orjson", specifier = "~=2.0.0" },
{ name = "flask-restful", specifier = "~=0.3.10" }, { name = "flask-restful", specifier = "~=0.3.10" },
{ name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" },
{ name = "gevent", specifier = "~=24.11.1" }, { name = "gevent", specifier = "~=24.11.1" },
{ url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" }, { url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" },
] ]


[[package]]
name = "flask-orjson"
version = "2.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "flask" },
{ name = "orjson" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a3/49/575796f6ddca171d82dbb12762e33166c8b8f8616c946f0a6dfbb9bc3cd6/flask_orjson-2.0.0.tar.gz", hash = "sha256:6df6631437f9bc52cf9821735f896efa5583b5f80712f7d29d9ef69a79986a9c", size = 2974, upload-time = "2024-01-15T00:03:22.236Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/ca/53e14be018a2284acf799830e8cd8e0b263c0fd3dff1ad7b35f8417e7067/flask_orjson-2.0.0-py3-none-any.whl", hash = "sha256:5d15f2ba94b8d6c02aee88fc156045016e83db9eda2c30545fabd640aebaec9d", size = 3622, upload-time = "2024-01-15T00:03:17.511Z" },
]

[[package]] [[package]]
name = "flask-restful" name = "flask-restful"
version = "0.3.10" version = "0.3.10"

Loading…
Peruuta
Tallenna