Nelze vybrat více než 25 témat Téma musí začínat písmenem nebo číslem, může obsahovat pomlčky („-“) a může být dlouhé až 35 znaků.

code.py 5.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #
  2. # Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import base64
  17. from abc import ABC
  18. from enum import Enum
  19. from typing import Optional
  20. from pydantic import BaseModel, Field, field_validator
  21. from agent.component.base import ComponentBase, ComponentParamBase
  22. from api import settings
  23. class Language(str, Enum):
  24. PYTHON = "python"
  25. NODEJS = "nodejs"
  26. class CodeExecutionRequest(BaseModel):
  27. code_b64: str = Field(..., description="Base64 encoded code string")
  28. language: Language = Field(default=Language.PYTHON, description="Programming language")
  29. arguments: Optional[dict] = Field(default={}, description="Arguments")
  30. @field_validator("code_b64")
  31. @classmethod
  32. def validate_base64(cls, v: str) -> str:
  33. try:
  34. base64.b64decode(v, validate=True)
  35. return v
  36. except Exception as e:
  37. raise ValueError(f"Invalid base64 encoding: {str(e)}")
  38. @field_validator("language", mode="before")
  39. @classmethod
  40. def normalize_language(cls, v) -> str:
  41. if isinstance(v, str):
  42. low = v.lower()
  43. if low in ("python", "python3"):
  44. return "python"
  45. elif low in ("javascript", "nodejs"):
  46. return "nodejs"
  47. raise ValueError(f"Unsupported language: {v}")
  48. class CodeParam(ComponentParamBase):
  49. """
  50. Define the code sandbox component parameters.
  51. """
  52. def __init__(self):
  53. super().__init__()
  54. self.lang = "python"
  55. self.script = ""
  56. self.arguments = []
  57. self.address = f"http://{settings.SANDBOX_HOST}:9385/run"
  58. self.enable_network = True
  59. def check(self):
  60. self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"])
  61. self.check_defined_type(self.enable_network, "Enable network", ["bool"])
  62. class Code(ComponentBase, ABC):
  63. component_name = "Code"
  64. def _run(self, history, **kwargs):
  65. arguments = {}
  66. for input in self._param.arguments:
  67. if "@" in input["component_id"]:
  68. component_id = input["component_id"].split("@")[0]
  69. referred_component_key = input["component_id"].split("@")[1]
  70. referred_component = self._canvas.get_component(component_id)["obj"]
  71. for param in referred_component._param.query:
  72. if param["key"] == referred_component_key:
  73. if "value" in param:
  74. arguments[input["name"]] = param["value"]
  75. else:
  76. referred_component = self._canvas.get_component(input["component_id"])["obj"]
  77. referred_component_name = referred_component.component_name
  78. referred_component_id = referred_component._id
  79. debug_inputs = self._param.debug_inputs
  80. if debug_inputs:
  81. for param in debug_inputs:
  82. if param["key"] == referred_component_id:
  83. if "value" in param and param["name"] == input["name"]:
  84. arguments[input["name"]] = param["value"]
  85. else:
  86. if referred_component_name.lower() == "answer":
  87. arguments[input["name"]] = self._canvas.get_history(1)[0]["content"]
  88. continue
  89. _, out = referred_component.output(allow_partial=False)
  90. if not out.empty:
  91. arguments[input["name"]] = "\n".join(out["content"])
  92. return self._execute_code(
  93. language=self._param.lang,
  94. code=self._param.script,
  95. arguments=arguments,
  96. address=self._param.address,
  97. enable_network=self._param.enable_network,
  98. )
  99. def _execute_code(self, language: str, code: str, arguments: dict, address: str, enable_network: bool):
  100. import requests
  101. try:
  102. code_b64 = self._encode_code(code)
  103. code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
  104. except Exception as e:
  105. return Code.be_output("**Error**: construct code request error: " + str(e))
  106. try:
  107. resp = requests.post(url=address, json=code_req, timeout=10)
  108. body = resp.json()
  109. if body:
  110. stdout = body.get("stdout")
  111. stderr = body.get("stderr")
  112. return Code.be_output(stdout or stderr)
  113. else:
  114. return Code.be_output("**Error**: There is no response from sanbox")
  115. except Exception as e:
  116. return Code.be_output("**Error**: Internal error in sanbox: " + str(e))
  117. def _encode_code(self, code: str) -> str:
  118. return base64.b64encode(code.encode("utf-8")).decode("utf-8")
  119. def get_input_elements(self):
  120. elements = []
  121. for input in self._param.arguments:
  122. cpn_id = input["component_id"]
  123. elements.append({"key": cpn_id, "name": input["name"]})
  124. return elements
  125. def debug(self, **kwargs):
  126. return self._run([], **kwargs)