Просмотр исходного кода

fix: drop some type fixme (#20344)

tags/1.4.2
yihong 5 месяцев назад
Родитель
Сommit
5a991295e0
Аккаунт пользователя с таким Email не найден

+ 12
- 11
api/core/model_runtime/utils/encoders.py Просмотреть файл

sqlalchemy_safe=sqlalchemy_safe, sqlalchemy_safe=sqlalchemy_safe,
) )
if dataclasses.is_dataclass(obj): if dataclasses.is_dataclass(obj):
# FIXME: mypy error, try to fix it instead of using type: ignore
obj_dict = dataclasses.asdict(obj) # type: ignore
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
# Ensure obj is a dataclass instance, not a dataclass type
if not isinstance(obj, type):
obj_dict = dataclasses.asdict(obj)
return jsonable_encoder(
obj_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
custom_encoder=custom_encoder,
sqlalchemy_safe=sqlalchemy_safe,
)
if isinstance(obj, Enum): if isinstance(obj, Enum):
return obj.value return obj.value
if isinstance(obj, PurePath): if isinstance(obj, PurePath):

+ 0
- 1
api/core/rag/datasource/vdb/baidu/baidu_vector.py Просмотреть файл

end = min(start + batch_size, total_count) end = min(start + batch_size, total_count)
rows = [] rows = []
assert len(metadatas) == total_count, "metadatas length should be equal to total_count" assert len(metadatas) == total_count, "metadatas length should be equal to total_count"
# FIXME do you need this assert?
for i in range(start, end, 1): for i in range(start, end, 1):
row = Row( row = Row(
id=metadatas[i].get("doc_id", str(uuid.uuid4())), id=metadatas[i].get("doc_id", str(uuid.uuid4())),

+ 1
- 1
api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py Просмотреть файл

return cluster_infos return cluster_infos
else: else:
response.raise_for_status() response.raise_for_status()
return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception
return []

+ 0
- 1
api/core/tools/entities/tool_entities.py Просмотреть файл

:param options: the options of the parameter :param options: the options of the parameter
""" """
# convert options to ToolParameterOption # convert options to ToolParameterOption
# FIXME fix the type error
if options: if options:
option_objs = [ option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))

+ 0
- 1
api/core/tools/utils/message_transformer.py Просмотреть файл

if not isinstance(message.message, ToolInvokeMessage.BlobMessage): if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
raise ValueError("unexpected message type") raise ValueError("unexpected message type")


# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes) assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager() tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_raw( file = tool_file_manager.create_file_by_raw(

+ 0
- 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py Просмотреть файл

:param node_data: node data :param node_data: node data
:return: :return:
""" """
# FIXME: fix the type error later
variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}


if node_data.instruction: if node_data.instruction:

+ 2
- 2
api/factories/variable_factory.py Просмотреть файл

raise VariableError("missing value type") raise VariableError("missing value type")
if (value := mapping.get("value")) is None: if (value := mapping.get("value")) is None:
raise VariableError("missing value") raise VariableError("missing value")
# FIXME: using Any here, fix it later
result: Any
result: Variable
match value_type: match value_type:
case SegmentType.STRING: case SegmentType.STRING:
result = StringVariable.model_validate(mapping) result = StringVariable.model_validate(mapping)

+ 1
- 2
api/schedule/clean_messages.py Просмотреть файл

while True: while True:
try: try:
# Main query with join and filter # Main query with join and filter
# FIXME:for mypy no paginate method error
messages = ( messages = (
db.session.query(Message) # type: ignore
db.session.query(Message)
.filter(Message.created_at < plan_sandbox_clean_message_day) .filter(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(100) .limit(100)

+ 8
- 8
api/services/ops_service.py Просмотреть файл

from typing import Optional
from typing import Any, Optional


from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, TraceAppConfig from models.model import App, TraceAppConfig
except KeyError: except KeyError:
return {"error": f"Invalid tracing provider: {tracing_provider}"} return {"error": f"Invalid tracing provider: {tracing_provider}"}


config_class, other_keys = (
provider_config_map[tracing_provider]["config_class"],
provider_config_map[tracing_provider]["other_keys"],
)
# FIXME: ignore type error
default_config_instance = config_class(**tracing_config) # type: ignore
for key in other_keys: # type: ignore
provider_config: dict[str, Any] = provider_config_map[tracing_provider]
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]

default_config_instance: BaseTracingConfig = config_class(**tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "": if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None) tracing_config[key] = getattr(default_config_instance, key, None)



+ 19
- 17
api/services/website_service.py Просмотреть файл

return crawl_status_data return crawl_status_data


@classmethod @classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None:
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
# decrypt api_key # decrypt api_key
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
# FIXME data is redefine too many times here, use Any to ease the type checking, fix it later
data: Any

if provider == "firecrawl": if provider == "firecrawl":
crawl_data: list[dict[str, Any]] | None = None
file_key = "website_files/" + job_id + ".txt" file_key = "website_files/" + job_id + ".txt"
if storage.exists(file_key): if storage.exists(file_key):
d = storage.load_once(file_key)
if d:
data = json.loads(d.decode("utf-8"))
stored_data = storage.load_once(file_key)
if stored_data:
crawl_data = json.loads(stored_data.decode("utf-8"))
else: else:
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
result = firecrawl_app.check_crawl_status(job_id) result = firecrawl_app.check_crawl_status(job_id)
if result.get("status") != "completed": if result.get("status") != "completed":
raise ValueError("Crawl job is not completed") raise ValueError("Crawl job is not completed")
data = result.get("data")
if data:
for item in data:
crawl_data = result.get("data")

if crawl_data:
for item in crawl_data:
if item.get("source_url") == url: if item.get("source_url") == url:
return dict(item) return dict(item)
return None return None
raise ValueError("Failed to crawl") raise ValueError("Failed to crawl")
return dict(response.json().get("data", {})) return dict(response.json().get("data", {}))
else: else:
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
response = requests.post(
# Get crawl status first
status_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
) )
data = response.json().get("data", {})
if data.get("status") != "completed":
status_data = status_response.json().get("data", {})
if status_data.get("status") != "completed":
raise ValueError("Crawl job is not completed") raise ValueError("Crawl job is not completed")


response = requests.post(
# Get processed data
data_response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
) )
data = response.json().get("data", {})
for item in data.get("processed", {}).values():
processed_data = data_response.json().get("data", {})
for item in processed_data.get("processed", {}).values():
if item.get("data", {}).get("url") == url: if item.get("data", {}).get("url") == url:
return dict(item.get("data", {})) return dict(item.get("data", {}))
return None return None

Загрузка…
Отмена
Сохранить