| @@ -1,5 +1,3 @@ | |||
| import os | |||
| from flask_login import current_user | |||
| from flask_restful import Resource, reqparse | |||
| @@ -29,15 +27,12 @@ class RuleGenerateApi(Resource): | |||
| args = parser.parse_args() | |||
| account = current_user | |||
| PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) | |||
| try: | |||
| rules = LLMGenerator.generate_rule_config( | |||
| tenant_id=account.current_tenant_id, | |||
| instruction=args["instruction"], | |||
| model_config=args["model_config"], | |||
| no_variable=args["no_variable"], | |||
| rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -64,14 +59,12 @@ class RuleCodeGenerateApi(Resource): | |||
| args = parser.parse_args() | |||
| account = current_user | |||
| CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024")) | |||
| try: | |||
| code_result = LLMGenerator.generate_code( | |||
| tenant_id=account.current_tenant_id, | |||
| instruction=args["instruction"], | |||
| model_config=args["model_config"], | |||
| code_language=args["code_language"], | |||
| max_tokens=CODE_GENERATION_MAX_TOKENS, | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -125,16 +125,13 @@ class LLMGenerator: | |||
| return questions | |||
| @classmethod | |||
| def generate_rule_config( | |||
| cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512 | |||
| ) -> dict: | |||
| def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: | |||
| output_parser = RuleConfigGeneratorOutputParser() | |||
| error = "" | |||
| error_step = "" | |||
| rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} | |||
| model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01} | |||
| model_parameters = model_config.get("completion_params", {}) | |||
| if no_variable: | |||
| prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) | |||
| @@ -276,12 +273,7 @@ class LLMGenerator: | |||
| @classmethod | |||
| def generate_code( | |||
| cls, | |||
| tenant_id: str, | |||
| instruction: str, | |||
| model_config: dict, | |||
| code_language: str = "javascript", | |||
| max_tokens: int = 1000, | |||
| cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript" | |||
| ) -> dict: | |||
| if code_language == "python": | |||
| prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) | |||
| @@ -305,8 +297,7 @@ class LLMGenerator: | |||
| ) | |||
| prompt_messages = [UserPromptMessage(content=prompt)] | |||
| model_parameters = {"max_tokens": max_tokens, "temperature": 0.01} | |||
| model_parameters = model_config.get("completion_params", {}) | |||
| try: | |||
| response = cast( | |||
| LLMResult, | |||