소스 검색

Improve: support custom model parameters in auto-generator (#22924)

tags/1.7.1
quicksand 3 달 전
부모
커밋
8340d775bd
No account linked to committer's email address
2개의 변경된 파일4개의 추가작업 그리고 20개의 파일을 삭제
  1. 0
    7
      api/controllers/console/app/generator.py
  2. 4
    13
      api/core/llm_generator/llm_generator.py

+ 0
- 7
api/controllers/console/app/generator.py 파일 보기

import os

from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse


args = parser.parse_args() args = parser.parse_args()


account = current_user account = current_user
PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512"))

try: try:
rules = LLMGenerator.generate_rule_config( rules = LLMGenerator.generate_rule_config(
tenant_id=account.current_tenant_id, tenant_id=account.current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
no_variable=args["no_variable"], no_variable=args["no_variable"],
rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
args = parser.parse_args() args = parser.parse_args()


account = current_user account = current_user
CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
try: try:
code_result = LLMGenerator.generate_code( code_result = LLMGenerator.generate_code(
tenant_id=account.current_tenant_id, tenant_id=account.current_tenant_id,
instruction=args["instruction"], instruction=args["instruction"],
model_config=args["model_config"], model_config=args["model_config"],
code_language=args["code_language"], code_language=args["code_language"],
max_tokens=CODE_GENERATION_MAX_TOKENS,
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)

+ 4
- 13
api/core/llm_generator/llm_generator.py 파일 보기

return questions return questions


@classmethod @classmethod
def generate_rule_config(
cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool, rule_config_max_tokens: int = 512
) -> dict:
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
output_parser = RuleConfigGeneratorOutputParser() output_parser = RuleConfigGeneratorOutputParser()


error = "" error = ""
error_step = "" error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""} rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
model_parameters = {"max_tokens": rule_config_max_tokens, "temperature": 0.01}

model_parameters = model_config.get("completion_params", {})
if no_variable: if no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE) prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)




@classmethod @classmethod
def generate_code( def generate_code(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
code_language: str = "javascript",
max_tokens: int = 1000,
cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
) -> dict: ) -> dict:
if code_language == "python": if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
) )


prompt_messages = [UserPromptMessage(content=prompt)] prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = {"max_tokens": max_tokens, "temperature": 0.01}

model_parameters = model_config.get("completion_params", {})
try: try:
response = cast( response = cast(
LLMResult, LLMResult,

Loading…
취소
저장