Co-authored-by: takatost <takatost@gmail.com>tags/0.3.27
| @@ -31,6 +31,7 @@ model_templates = { | |||
| 'model': json.dumps({ | |||
| "provider": "openai", | |||
| "name": "gpt-3.5-turbo-instruct", | |||
| "mode": "completion", | |||
| "completion_params": { | |||
| "max_tokens": 512, | |||
| "temperature": 1, | |||
| @@ -81,6 +82,7 @@ model_templates = { | |||
| 'model': json.dumps({ | |||
| "provider": "openai", | |||
| "name": "gpt-3.5-turbo", | |||
| "mode": "chat", | |||
| "completion_params": { | |||
| "max_tokens": 512, | |||
| "temperature": 1, | |||
| @@ -137,10 +139,11 @@ demo_model_templates = { | |||
| }, | |||
| opening_statement='', | |||
| suggested_questions=None, | |||
| pre_prompt="Please translate the following text into {{target_language}}:\n", | |||
| pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:", | |||
| model=json.dumps({ | |||
| "provider": "openai", | |||
| "name": "gpt-3.5-turbo-instruct", | |||
| "mode": "completion", | |||
| "completion_params": { | |||
| "max_tokens": 1000, | |||
| "temperature": 0, | |||
| @@ -169,6 +172,13 @@ demo_model_templates = { | |||
| 'Italian', | |||
| ] | |||
| } | |||
| },{ | |||
| "paragraph": { | |||
| "label": "Query", | |||
| "variable": "query", | |||
| "required": True, | |||
| "default": "" | |||
| } | |||
| } | |||
| ]) | |||
| ) | |||
| @@ -200,6 +210,7 @@ demo_model_templates = { | |||
| model=json.dumps({ | |||
| "provider": "openai", | |||
| "name": "gpt-3.5-turbo", | |||
| "mode": "chat", | |||
| "completion_params": { | |||
| "max_tokens": 300, | |||
| "temperature": 0.8, | |||
| @@ -255,10 +266,11 @@ demo_model_templates = { | |||
| }, | |||
| opening_statement='', | |||
| suggested_questions=None, | |||
| pre_prompt="请将以下文本翻译为{{target_language}}:\n", | |||
| pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:", | |||
| model=json.dumps({ | |||
| "provider": "openai", | |||
| "name": "gpt-3.5-turbo-instruct", | |||
| "mode": "completion", | |||
| "completion_params": { | |||
| "max_tokens": 1000, | |||
| "temperature": 0, | |||
| @@ -287,6 +299,13 @@ demo_model_templates = { | |||
| "意大利语", | |||
| ] | |||
| } | |||
| },{ | |||
| "paragraph": { | |||
| "label": "文本内容", | |||
| "variable": "query", | |||
| "required": True, | |||
| "default": "" | |||
| } | |||
| } | |||
| ]) | |||
| ) | |||
| @@ -318,6 +337,7 @@ demo_model_templates = { | |||
| model=json.dumps({ | |||
| "provider": "openai", | |||
| "name": "gpt-3.5-turbo", | |||
| "mode": "chat", | |||
| "completion_params": { | |||
| "max_tokens": 300, | |||
| "temperature": 0.8, | |||
| @@ -9,7 +9,7 @@ api = ExternalApi(bp) | |||
| from . import setup, version, apikey, admin | |||
| # Import app controllers | |||
| from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio | |||
| from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio | |||
| # Import auth controllers | |||
| from .auth import login, oauth, data_source_oauth, activate | |||
| @@ -0,0 +1,26 @@ | |||
| from flask_restful import Resource, reqparse | |||
| from controllers.console import api | |||
| from controllers.console.setup import setup_required | |||
| from controllers.console.wraps import account_initialization_required | |||
| from libs.login import login_required | |||
| from services.advanced_prompt_template_service import AdvancedPromptTemplateService | |||
| class AdvancedPromptTemplateList(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('app_mode', type=str, required=True, location='args') | |||
| parser.add_argument('model_mode', type=str, required=True, location='args') | |||
| parser.add_argument('has_context', type=str, required=False, default='true', location='args') | |||
| parser.add_argument('model_name', type=str, required=True, location='args') | |||
| args = parser.parse_args() | |||
| service = AdvancedPromptTemplateService() | |||
| return service.get_prompt(args) | |||
| api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') | |||
| @@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE | |||
| LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError | |||
| class IntroductionGenerateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def post(self): | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument('prompt_template', type=str, required=True, location='json') | |||
| args = parser.parse_args() | |||
| account = current_user | |||
| try: | |||
| answer = LLMGenerator.generate_introduction( | |||
| account.current_tenant_id, | |||
| args['prompt_template'] | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except QuotaExceededError: | |||
| raise ProviderQuotaExceededError() | |||
| except ModelCurrentlyNotSupportError: | |||
| raise ProviderModelCurrentlyNotSupportError() | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, LLMAuthorizationError) as e: | |||
| raise CompletionRequestError(str(e)) | |||
| return {'introduction': answer} | |||
| class RuleGenerateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @@ -72,5 +43,4 @@ class RuleGenerateApi(Resource): | |||
| return rules | |||
| api.add_resource(IntroductionGenerateApi, '/introduction-generate') | |||
| api.add_resource(RuleGenerateApi, '/rule-generate') | |||
| @@ -329,7 +329,7 @@ class MessageApi(Resource): | |||
| message_id = str(message_id) | |||
| # get app info | |||
| app_model = _get_app(app_id, 'chat') | |||
| app_model = _get_app(app_id) | |||
| message = db.session.query(Message).filter( | |||
| Message.id == message_id, | |||
| @@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource): | |||
| streaming = args['response_mode'] == 'streaming' | |||
| try: | |||
| response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming) | |||
| response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app') | |||
| return compact_response(response) | |||
| except MessageNotExistsError: | |||
| raise NotFound("Message Not Exists.") | |||
| @@ -1,4 +1,3 @@ | |||
| import json | |||
| import logging | |||
| from typing import Optional, List, Union | |||
| @@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import PromptMessage | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.orchestrator_rule_parser import OrchestratorRuleParser | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT | |||
| from models.dataset import DocumentSegment, Dataset, Document | |||
| from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from models.model import App, AppModelConfig, Account, Conversation, EndUser | |||
| class Completion: | |||
| @@ -30,7 +27,7 @@ class Completion: | |||
| """ | |||
| errors: ProviderTokenNotInitError | |||
| """ | |||
| query = PromptBuilder.process_template(query) | |||
| query = PromptTemplateParser.remove_template_variables(query) | |||
| memory = None | |||
| if conversation: | |||
| @@ -160,14 +157,28 @@ class Completion: | |||
| memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], | |||
| fake_response: Optional[str]): | |||
| # get llm prompt | |||
| prompt_messages, stop_words = model_instance.get_prompt( | |||
| mode=mode, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| inputs=inputs, | |||
| query=query, | |||
| context=agent_execute_result.output if agent_execute_result else None, | |||
| memory=memory | |||
| ) | |||
| if app_model_config.prompt_type == 'simple': | |||
| prompt_messages, stop_words = model_instance.get_prompt( | |||
| mode=mode, | |||
| pre_prompt=app_model_config.pre_prompt, | |||
| inputs=inputs, | |||
| query=query, | |||
| context=agent_execute_result.output if agent_execute_result else None, | |||
| memory=memory | |||
| ) | |||
| else: | |||
| prompt_messages = model_instance.get_advanced_prompt( | |||
| app_mode=mode, | |||
| app_model_config=app_model_config, | |||
| inputs=inputs, | |||
| query=query, | |||
| context=agent_execute_result.output if agent_execute_result else None, | |||
| memory=memory | |||
| ) | |||
| model_config = app_model_config.model_dict | |||
| completion_params = model_config.get("completion_params", {}) | |||
| stop_words = completion_params.get("stop", []) | |||
| cls.recale_llm_max_tokens( | |||
| model_instance=model_instance, | |||
| @@ -176,7 +187,7 @@ class Completion: | |||
| response = model_instance.run( | |||
| messages=prompt_messages, | |||
| stop=stop_words, | |||
| stop=stop_words if stop_words else None, | |||
| callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], | |||
| fake_response=fake_response | |||
| ) | |||
| @@ -266,52 +277,3 @@ class Completion: | |||
| model_kwargs = model_instance.get_model_kwargs() | |||
| model_kwargs.max_tokens = max_tokens | |||
| model_instance.set_model_kwargs(model_kwargs) | |||
| @classmethod | |||
| def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, | |||
| app_model_config: AppModelConfig, user: Account, streaming: bool): | |||
| final_model_instance = ModelFactory.get_text_generation_model_from_model_config( | |||
| tenant_id=app.tenant_id, | |||
| model_config=app_model_config.model_dict, | |||
| streaming=streaming | |||
| ) | |||
| # get llm prompt | |||
| old_prompt_messages, _ = final_model_instance.get_prompt( | |||
| mode='completion', | |||
| pre_prompt=pre_prompt, | |||
| inputs=message.inputs, | |||
| query=message.query, | |||
| context=None, | |||
| memory=None | |||
| ) | |||
| original_completion = message.answer.strip() | |||
| prompt = MORE_LIKE_THIS_GENERATE_PROMPT | |||
| prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion) | |||
| prompt_messages = [PromptMessage(content=prompt)] | |||
| conversation_message_task = ConversationMessageTask( | |||
| task_id=task_id, | |||
| app=app, | |||
| app_model_config=app_model_config, | |||
| user=user, | |||
| inputs=message.inputs, | |||
| query=message.query, | |||
| is_override=True if message.override_model_configs else False, | |||
| streaming=streaming, | |||
| model_instance=final_model_instance | |||
| ) | |||
| cls.recale_llm_max_tokens( | |||
| model_instance=final_model_instance, | |||
| prompt_messages=prompt_messages | |||
| ) | |||
| final_model_instance.run( | |||
| messages=prompt_messages, | |||
| callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)] | |||
| ) | |||
| @@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory | |||
| from core.model_providers.models.entity.message import to_prompt_messages, MessageType | |||
| from core.model_providers.models.llm.base import BaseLLM | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from events.message_event import message_was_created | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| @@ -74,10 +74,10 @@ class ConversationMessageTask: | |||
| if self.mode == 'chat': | |||
| introduction = self.app_model_config.opening_statement | |||
| if introduction: | |||
| prompt_template = JinjaPromptTemplate.from_template(template=introduction) | |||
| prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs} | |||
| prompt_template = PromptTemplateParser(template=introduction) | |||
| prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs} | |||
| try: | |||
| introduction = prompt_template.format(**prompt_inputs) | |||
| introduction = prompt_template.format(prompt_inputs) | |||
| except KeyError: | |||
| pass | |||
| @@ -150,12 +150,12 @@ class ConversationMessageTask: | |||
| message_tokens = llm_message.prompt_tokens | |||
| answer_tokens = llm_message.completion_tokens | |||
| message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) | |||
| message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN) | |||
| message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER) | |||
| message_price_unit = self.model_instance.get_price_unit(MessageType.USER) | |||
| answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT) | |||
| message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) | |||
| message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER) | |||
| answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) | |||
| total_price = message_total_price + answer_total_price | |||
| @@ -163,7 +163,7 @@ class ConversationMessageTask: | |||
| self.message.message_tokens = message_tokens | |||
| self.message.message_unit_price = message_unit_price | |||
| self.message.message_price_unit = message_price_unit | |||
| self.message.answer = PromptBuilder.process_template( | |||
| self.message.answer = PromptTemplateParser.remove_template_variables( | |||
| llm_message.completion.strip()) if llm_message.completion else '' | |||
| self.message.answer_tokens = answer_tokens | |||
| self.message.answer_unit_price = answer_unit_price | |||
| @@ -226,15 +226,15 @@ class ConversationMessageTask: | |||
| def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM, | |||
| agent_loop: AgentLoop): | |||
| agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN) | |||
| agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN) | |||
| agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER) | |||
| agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER) | |||
| agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT) | |||
| agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT) | |||
| loop_message_tokens = agent_loop.prompt_tokens | |||
| loop_answer_tokens = agent_loop.completion_tokens | |||
| loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN) | |||
| loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER) | |||
| loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) | |||
| loop_total_price = loop_message_total_price + loop_answer_total_price | |||
| @@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs | |||
| from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser | |||
| from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser | |||
| from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate | |||
| from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ | |||
| GENERATOR_QA_PROMPT | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT | |||
| class LLMGenerator: | |||
| @@ -44,78 +43,19 @@ class LLMGenerator: | |||
| return answer.strip() | |||
| @classmethod | |||
| def generate_conversation_summary(cls, tenant_id: str, messages): | |||
| max_tokens = 200 | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| model_kwargs=ModelKwargs( | |||
| max_tokens=max_tokens | |||
| ) | |||
| ) | |||
| prompt = CONVERSATION_SUMMARY_PROMPT | |||
| prompt_with_empty_context = prompt.format(context='') | |||
| prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)]) | |||
| max_context_token_length = model_instance.model_rules.max_tokens.max | |||
| max_context_token_length = max_context_token_length if max_context_token_length else 1500 | |||
| rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1 | |||
| context = '' | |||
| for message in messages: | |||
| if not message.answer: | |||
| continue | |||
| if len(message.query) > 2000: | |||
| query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:] | |||
| else: | |||
| query = message.query | |||
| if len(message.answer) > 2000: | |||
| answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:] | |||
| else: | |||
| answer = message.answer | |||
| message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer | |||
| if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0: | |||
| context += message_qa_text | |||
| if not context: | |||
| return '[message too long, no summary]' | |||
| prompt = prompt.format(context=context) | |||
| prompts = [PromptMessage(content=prompt)] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| return answer.strip() | |||
| @classmethod | |||
| def generate_introduction(cls, tenant_id: str, pre_prompt: str): | |||
| prompt = INTRODUCTION_GENERATE_PROMPT | |||
| prompt = prompt.format(prompt=pre_prompt) | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id | |||
| ) | |||
| prompts = [PromptMessage(content=prompt)] | |||
| response = model_instance.run(prompts) | |||
| answer = response.content | |||
| return answer.strip() | |||
| @classmethod | |||
| def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): | |||
| output_parser = SuggestedQuestionsAfterAnswerOutputParser() | |||
| format_instructions = output_parser.get_format_instructions() | |||
| prompt = JinjaPromptTemplate( | |||
| template="{{histories}}\n{{format_instructions}}\nquestions:\n", | |||
| input_variables=["histories"], | |||
| partial_variables={"format_instructions": format_instructions} | |||
| prompt_template = PromptTemplateParser( | |||
| template="{{histories}}\n{{format_instructions}}\nquestions:\n" | |||
| ) | |||
| _input = prompt.format_prompt(histories=histories) | |||
| prompt = prompt_template.format({ | |||
| "histories": histories, | |||
| "format_instructions": format_instructions | |||
| }) | |||
| try: | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| @@ -128,10 +68,10 @@ class LLMGenerator: | |||
| except ProviderTokenNotInitError: | |||
| return [] | |||
| prompts = [PromptMessage(content=_input.to_string())] | |||
| prompt_messages = [PromptMessage(content=prompt)] | |||
| try: | |||
| output = model_instance.run(prompts) | |||
| output = model_instance.run(prompt_messages) | |||
| questions = output_parser.parse(output.content) | |||
| except LLMError: | |||
| questions = [] | |||
| @@ -145,19 +85,21 @@ class LLMGenerator: | |||
| def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict: | |||
| output_parser = RuleConfigGeneratorOutputParser() | |||
| prompt = OutLinePromptTemplate( | |||
| template=output_parser.get_format_instructions(), | |||
| input_variables=["audiences", "hoping_to_solve"], | |||
| partial_variables={ | |||
| "variable": '{variable}', | |||
| "lanA": '{lanA}', | |||
| "lanB": '{lanB}', | |||
| "topic": '{topic}' | |||
| }, | |||
| validate_template=False | |||
| prompt_template = PromptTemplateParser( | |||
| template=output_parser.get_format_instructions() | |||
| ) | |||
| _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) | |||
| prompt = prompt_template.format( | |||
| inputs={ | |||
| "audiences": audiences, | |||
| "hoping_to_solve": hoping_to_solve, | |||
| "variable": "{{variable}}", | |||
| "lanA": "{{lanA}}", | |||
| "lanB": "{{lanB}}", | |||
| "topic": "{{topic}}" | |||
| }, | |||
| remove_template_variables=False | |||
| ) | |||
| model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=tenant_id, | |||
| @@ -167,10 +109,10 @@ class LLMGenerator: | |||
| ) | |||
| ) | |||
| prompts = [PromptMessage(content=_input.to_string())] | |||
| prompt_messages = [PromptMessage(content=prompt)] | |||
| try: | |||
| output = model_instance.run(prompts) | |||
| output = model_instance.run(prompt_messages) | |||
| rule_config = output_parser.parse(output.content) | |||
| except LLMError as e: | |||
| raise e | |||
| @@ -286,7 +286,7 @@ class IndexingRunner: | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), | |||
| "currency": embedding_model.get_currency(), | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| @@ -383,7 +383,7 @@ class IndexingRunner: | |||
| "total_segments": total_segments * 20, | |||
| "tokens": total_segments * 2000, | |||
| "total_price": '{:f}'.format( | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), | |||
| text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), | |||
| "currency": embedding_model.get_currency(), | |||
| "qa_preview": document_qa_list, | |||
| "preview": preview_texts | |||
| @@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): | |||
| chat_messages: List[PromptMessage] = [] | |||
| for message in messages: | |||
| chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN)) | |||
| chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER)) | |||
| chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) | |||
| if not chat_messages: | |||
| @@ -13,13 +13,13 @@ class LLMRunResult(BaseModel): | |||
| class MessageType(enum.Enum): | |||
| HUMAN = 'human' | |||
| USER = 'user' | |||
| ASSISTANT = 'assistant' | |||
| SYSTEM = 'system' | |||
| class PromptMessage(BaseModel): | |||
| type: MessageType = MessageType.HUMAN | |||
| type: MessageType = MessageType.USER | |||
| content: str = '' | |||
| function_call: dict = None | |||
| @@ -27,7 +27,7 @@ class PromptMessage(BaseModel): | |||
| def to_lc_messages(messages: list[PromptMessage]): | |||
| lc_messages = [] | |||
| for message in messages: | |||
| if message.type == MessageType.HUMAN: | |||
| if message.type == MessageType.USER: | |||
| lc_messages.append(HumanMessage(content=message.content)) | |||
| elif message.type == MessageType.ASSISTANT: | |||
| additional_kwargs = {} | |||
| @@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]): | |||
| prompt_messages = [] | |||
| for message in messages: | |||
| if isinstance(message, HumanMessage): | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER)) | |||
| elif isinstance(message, AIMessage): | |||
| message_kwargs = { | |||
| 'content': message.content, | |||
| @@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]): | |||
| elif isinstance(message, SystemMessage): | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) | |||
| elif isinstance(message, FunctionMessage): | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) | |||
| prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER)) | |||
| return prompt_messages | |||
| @@ -18,7 +18,7 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| from core.prompt.prompt_builder import PromptBuilder | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| from core.third_party.langchain.llms.fake import FakeLLM | |||
| import logging | |||
| @@ -232,7 +232,7 @@ class BaseLLM(BaseProviderModel): | |||
| :param message_type: | |||
| :return: | |||
| """ | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| if message_type == MessageType.USER or message_type == MessageType.SYSTEM: | |||
| unit_price = self.price_config['prompt'] | |||
| else: | |||
| unit_price = self.price_config['completion'] | |||
| @@ -250,7 +250,7 @@ class BaseLLM(BaseProviderModel): | |||
| :param message_type: | |||
| :return: decimal.Decimal('0.0001') | |||
| """ | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| if message_type == MessageType.USER or message_type == MessageType.SYSTEM: | |||
| unit_price = self.price_config['prompt'] | |||
| else: | |||
| unit_price = self.price_config['completion'] | |||
| @@ -265,7 +265,7 @@ class BaseLLM(BaseProviderModel): | |||
| :param message_type: | |||
| :return: decimal.Decimal('0.000001') | |||
| """ | |||
| if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: | |||
| if message_type == MessageType.USER or message_type == MessageType.SYSTEM: | |||
| price_unit = self.price_config['unit'] | |||
| else: | |||
| price_unit = self.price_config['unit'] | |||
| @@ -330,6 +330,85 @@ class BaseLLM(BaseProviderModel): | |||
| prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory) | |||
| return [PromptMessage(content=prompt)], stops | |||
| def get_advanced_prompt(self, app_mode: str, | |||
| app_model_config: str, inputs: dict, | |||
| query: str, | |||
| context: Optional[str], | |||
| memory: Optional[BaseChatMemory]) -> List[PromptMessage]: | |||
| model_mode = app_model_config.model_dict['mode'] | |||
| conversation_histories_role = {} | |||
| raw_prompt_list = [] | |||
| prompt_messages = [] | |||
| if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value: | |||
| prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text'] | |||
| raw_prompt_list = [{ | |||
| 'role': MessageType.USER.value, | |||
| 'text': prompt_text | |||
| }] | |||
| conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role'] | |||
| elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value: | |||
| raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] | |||
| elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value: | |||
| raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] | |||
| elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value: | |||
| prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text'] | |||
| raw_prompt_list = [{ | |||
| 'role': MessageType.USER.value, | |||
| 'text': prompt_text | |||
| }] | |||
| else: | |||
| raise Exception("app_mode or model_mode not support") | |||
| for prompt_item in raw_prompt_list: | |||
| prompt = prompt_item['text'] | |||
| # set prompt template variables | |||
| prompt_template = PromptTemplateParser(template=prompt) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| if '#context#' in prompt: | |||
| if context: | |||
| prompt_inputs['#context#'] = context | |||
| else: | |||
| prompt_inputs['#context#'] = '' | |||
| if '#query#' in prompt: | |||
| if query: | |||
| prompt_inputs['#query#'] = query | |||
| else: | |||
| prompt_inputs['#query#'] = '' | |||
| if '#histories#' in prompt: | |||
| if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value: | |||
| memory.human_prefix = conversation_histories_role['user_prefix'] | |||
| memory.ai_prefix = conversation_histories_role['assistant_prefix'] | |||
| histories = self._get_history_messages_from_memory(memory, 2000) | |||
| prompt_inputs['#histories#'] = histories | |||
| else: | |||
| prompt_inputs['#histories#'] = '' | |||
| prompt = prompt_template.format( | |||
| prompt_inputs | |||
| ) | |||
| prompt = re.sub(r'<\|.*?\|>', '', prompt) | |||
| prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) | |||
| if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value: | |||
| memory.human_prefix = MessageType.USER.value | |||
| memory.ai_prefix = MessageType.ASSISTANT.value | |||
| histories = self._get_history_messages_list_from_memory(memory, 2000) | |||
| prompt_messages.extend(histories) | |||
| if app_mode == 'chat' and model_mode == ModelMode.CHAT.value: | |||
| prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query)) | |||
| return prompt_messages | |||
| def prompt_file_name(self, mode: str) -> str: | |||
| if mode == 'completion': | |||
| return 'common_completion' | |||
| @@ -342,17 +421,17 @@ class BaseLLM(BaseProviderModel): | |||
| memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]: | |||
| context_prompt_content = '' | |||
| if context and 'context_prompt' in prompt_rules: | |||
| prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt']) | |||
| prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) | |||
| context_prompt_content = prompt_template.format( | |||
| context=context | |||
| {'context': context} | |||
| ) | |||
| pre_prompt_content = '' | |||
| if pre_prompt: | |||
| prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs} | |||
| prompt_template = PromptTemplateParser(template=pre_prompt) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| pre_prompt_content = prompt_template.format( | |||
| **prompt_inputs | |||
| prompt_inputs | |||
| ) | |||
| prompt = '' | |||
| @@ -385,10 +464,8 @@ class BaseLLM(BaseProviderModel): | |||
| memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' | |||
| histories = self._get_history_messages_from_memory(memory, rest_tokens) | |||
| prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt']) | |||
| histories_prompt_content = prompt_template.format( | |||
| histories=histories | |||
| ) | |||
| prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) | |||
| histories_prompt_content = prompt_template.format({'histories': histories}) | |||
| prompt = '' | |||
| for order in prompt_rules['system_prompt_orders']: | |||
| @@ -399,10 +476,8 @@ class BaseLLM(BaseProviderModel): | |||
| elif order == 'histories_prompt': | |||
| prompt += histories_prompt_content | |||
| prompt_template = JinjaPromptTemplate.from_template(template=query_prompt) | |||
| query_prompt_content = prompt_template.format( | |||
| query=query | |||
| ) | |||
| prompt_template = PromptTemplateParser(template=query_prompt) | |||
| query_prompt_content = prompt_template.format({'query': query}) | |||
| prompt += query_prompt_content | |||
| @@ -433,6 +508,16 @@ class BaseLLM(BaseProviderModel): | |||
| external_context = memory.load_memory_variables({}) | |||
| return external_context[memory_key] | |||
| def _get_history_messages_list_from_memory(self, memory: BaseChatMemory, | |||
| max_token_limit: int) -> List[PromptMessage]: | |||
| """Get memory messages.""" | |||
| memory.max_token_limit = max_token_limit | |||
| memory.return_messages = True | |||
| memory_key = memory.memory_variables[0] | |||
| external_context = memory.load_memory_variables({}) | |||
| memory.return_messages = False | |||
| return to_prompt_messages(external_context[memory_key]) | |||
| def _get_prompt_from_messages(self, messages: List[PromptMessage], | |||
| model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]: | |||
| if not model_mode: | |||
| @@ -9,7 +9,7 @@ from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode | |||
| from core.model_providers.models.entity.provider import ModelFeature | |||
| from core.model_providers.models.llm.anthropic_model import AnthropicModel | |||
| from core.model_providers.models.llm.base import ModelType | |||
| @@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'claude-instant-1', | |||
| 'name': 'claude-instant-1', | |||
| 'mode': ModelMode.CHAT.value, | |||
| }, | |||
| { | |||
| 'id': 'claude-2', | |||
| 'name': 'claude-2', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider): | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.CHAT.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -12,7 +12,7 @@ from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \ | |||
| AZURE_OPENAI_API_VERSION | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode | |||
| from core.model_providers.models.entity.provider import ModelFeature | |||
| from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| @@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| } | |||
| credentials = json.loads(provider_model.encrypted_config) | |||
| if provider_model.model_type == ModelType.TEXT_GENERATION.value: | |||
| model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name']) | |||
| if credentials['base_model_name'] in [ | |||
| 'gpt-4', | |||
| 'gpt-4-32k', | |||
| @@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| return model_list | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| if model_name == 'text-davinci-003': | |||
| return ModelMode.COMPLETION.value | |||
| else: | |||
| return ModelMode.CHAT.value | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| if model_type == ModelType.TEXT_GENERATION: | |||
| models = [ | |||
| { | |||
| 'id': 'gpt-3.5-turbo', | |||
| 'name': 'gpt-3.5-turbo', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-3.5-turbo-16k', | |||
| 'name': 'gpt-3.5-turbo-16k', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-4', | |||
| 'name': 'gpt-4', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-4-32k', | |||
| 'name': 'gpt-4-32k', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'text-davinci-003', | |||
| 'name': 'text-davinci-003', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| } | |||
| ] | |||
| @@ -6,7 +6,7 @@ from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.baichuan_model import BaichuanModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM | |||
| @@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider): | |||
| Returns the name of a provider. | |||
| """ | |||
| return 'baichuan' | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.CHAT.value | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| if model_type == ModelType.TEXT_GENERATION: | |||
| @@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'baichuan2-53b', | |||
| 'name': 'Baichuan2-53B', | |||
| 'mode': ModelMode.CHAT.value, | |||
| } | |||
| ] | |||
| else: | |||
| @@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC): | |||
| ProviderModel.is_valid == True | |||
| ).order_by(ProviderModel.created_at.asc()).all() | |||
| return [{ | |||
| 'id': provider_model.model_name, | |||
| 'name': provider_model.model_name | |||
| } for provider_model in provider_models] | |||
| provider_model_list = [] | |||
| for provider_model in provider_models: | |||
| provider_model_dict = { | |||
| 'id': provider_model.model_name, | |||
| 'name': provider_model.model_name | |||
| } | |||
| if model_type == ModelType.TEXT_GENERATION: | |||
| provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name) | |||
| provider_model_list.append(provider_model_dict) | |||
| return provider_model_list | |||
| @abstractmethod | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| @@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC): | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| """ | |||
| get text generation model mode. | |||
| :param model_name: | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def get_model_class(self, model_type: ModelType) -> Type: | |||
| """ | |||
| @@ -6,7 +6,7 @@ from langchain.llms import ChatGLM | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.chatglm_model import ChatGLMModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from models.provider import ProviderType | |||
| @@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'chatglm2-6b', | |||
| 'name': 'ChatGLM2-6B', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| }, | |||
| { | |||
| 'id': 'chatglm-6b', | |||
| 'name': 'ChatGLM-6B', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| } | |||
| ] | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -5,7 +5,7 @@ import requests | |||
| from huggingface_hub import HfApi | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode | |||
| from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| @@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider): | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -6,7 +6,7 @@ from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode | |||
| from core.model_providers.models.llm.localai_model import LocalAIModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| @@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider): | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION) | |||
| if credentials['completion_type'] == 'chat_completion': | |||
| return ModelMode.CHAT.value | |||
| else: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -7,7 +7,7 @@ from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.minimax_model import MinimaxModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM | |||
| @@ -29,10 +29,12 @@ class MinimaxProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'abab5.5-chat', | |||
| 'name': 'abab5.5-chat', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| }, | |||
| { | |||
| 'id': 'abab5-chat', | |||
| 'name': 'abab5-chat', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| } | |||
| ] | |||
| elif model_type == ModelType.EMBEDDINGS: | |||
| @@ -45,6 +47,9 @@ class MinimaxProvider(BaseModelProvider): | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature | |||
| from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.llm.openai_model import OpenAIModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS | |||
| from core.model_providers.models.moderation.openai_moderation import OpenAIModeration | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.model_providers.providers.hosted import hosted_model_providers | |||
| @@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-3.5-turbo', | |||
| 'name': 'gpt-3.5-turbo', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-3.5-turbo-instruct', | |||
| 'name': 'GPT-3.5-Turbo-Instruct', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| }, | |||
| { | |||
| 'id': 'gpt-3.5-turbo-16k', | |||
| 'name': 'gpt-3.5-turbo-16k', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-4', | |||
| 'name': 'gpt-4', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'gpt-4-32k', | |||
| 'name': 'gpt-4-32k', | |||
| 'mode': ModelMode.CHAT.value, | |||
| 'features': [ | |||
| ModelFeature.AGENT_THOUGHT.value | |||
| ] | |||
| @@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'text-davinci-003', | |||
| 'name': 'text-davinci-003', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| } | |||
| ] | |||
| @@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider): | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| if model_name in COMPLETION_MODELS: | |||
| return ModelMode.COMPLETION.value | |||
| else: | |||
| return ModelMode.CHAT.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -3,7 +3,7 @@ from typing import Type | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode | |||
| from core.model_providers.models.llm.openllm_model import OpenLLMModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| @@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider): | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -6,7 +6,8 @@ import replicate | |||
| from replicate.exceptions import ReplicateError | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType | |||
| from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \ | |||
| ModelMode | |||
| from core.model_providers.models.llm.replicate_model import ReplicateModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| @@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider): | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -7,7 +7,7 @@ from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.spark_model import SparkModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.spark import ChatSpark | |||
| @@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'spark', | |||
| 'name': 'Spark V1.5', | |||
| 'mode': ModelMode.CHAT.value, | |||
| }, | |||
| { | |||
| 'id': 'spark-v2', | |||
| 'name': 'Spark V2.0', | |||
| 'mode': ModelMode.CHAT.value, | |||
| } | |||
| ] | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.CHAT.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -4,7 +4,7 @@ from typing import Type | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.tongyi_model import TongyiModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi | |||
| @@ -26,15 +26,20 @@ class TongyiProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'qwen-turbo', | |||
| 'name': 'qwen-turbo', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| }, | |||
| { | |||
| 'id': 'qwen-plus', | |||
| 'name': 'qwen-plus', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| } | |||
| ] | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -4,7 +4,7 @@ from typing import Type | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.wenxin_model import WenxinModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.wenxin import Wenxin | |||
| @@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'ernie-bot', | |||
| 'name': 'ERNIE-Bot', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| }, | |||
| { | |||
| 'id': 'ernie-bot-turbo', | |||
| 'name': 'ERNIE-Bot-turbo', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| }, | |||
| { | |||
| 'id': 'bloomz-7b', | |||
| 'name': 'BLOOMZ-7B', | |||
| 'mode': ModelMode.COMPLETION.value, | |||
| } | |||
| ] | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -6,7 +6,7 @@ from langchain.embeddings import XinferenceEmbeddings | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType | |||
| from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode | |||
| from core.model_providers.models.llm.xinference_model import XinferenceModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| @@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider): | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -7,7 +7,7 @@ from langchain.schema import HumanMessage | |||
| from core.helper import encrypter | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType | |||
| from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode | |||
| from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel | |||
| from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError | |||
| from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM | |||
| @@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider): | |||
| { | |||
| 'id': 'chatglm_pro', | |||
| 'name': 'chatglm_pro', | |||
| 'mode': ModelMode.CHAT.value, | |||
| }, | |||
| { | |||
| 'id': 'chatglm_std', | |||
| 'name': 'chatglm_std', | |||
| 'mode': ModelMode.CHAT.value, | |||
| }, | |||
| { | |||
| 'id': 'chatglm_lite', | |||
| 'name': 'chatglm_lite', | |||
| 'mode': ModelMode.CHAT.value, | |||
| }, | |||
| { | |||
| 'id': 'chatglm_lite_32k', | |||
| 'name': 'chatglm_lite_32k', | |||
| 'mode': ModelMode.CHAT.value, | |||
| } | |||
| ] | |||
| elif model_type == ModelType.EMBEDDINGS: | |||
| @@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider): | |||
| else: | |||
| return [] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.CHAT.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| """ | |||
| Returns the model class. | |||
| @@ -1,4 +1,3 @@ | |||
| import math | |||
| from typing import Optional | |||
| from langchain import WikipediaAPIWrapper | |||
| @@ -50,6 +49,7 @@ class OrchestratorRuleParser: | |||
| tool_configs = agent_mode_config.get('tools', []) | |||
| agent_provider_name = model_dict.get('provider', 'openai') | |||
| agent_model_name = model_dict.get('name', 'gpt-4') | |||
| dataset_configs = self.app_model_config.dataset_configs_dict | |||
| agent_model_instance = ModelFactory.get_text_generation_model( | |||
| tenant_id=self.tenant_id, | |||
| @@ -96,13 +96,14 @@ class OrchestratorRuleParser: | |||
| summary_model_instance = None | |||
| tools = self.to_tools( | |||
| agent_model_instance=agent_model_instance, | |||
| tool_configs=tool_configs, | |||
| callbacks=[agent_callback, DifyStdOutCallbackHandler()], | |||
| agent_model_instance=agent_model_instance, | |||
| conversation_message_task=conversation_message_task, | |||
| rest_tokens=rest_tokens, | |||
| callbacks=[agent_callback, DifyStdOutCallbackHandler()], | |||
| return_resource=return_resource, | |||
| retriever_from=retriever_from | |||
| retriever_from=retriever_from, | |||
| dataset_configs=dataset_configs | |||
| ) | |||
| if len(tools) == 0: | |||
| @@ -170,20 +171,12 @@ class OrchestratorRuleParser: | |||
| return None | |||
| def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, | |||
| conversation_message_task: ConversationMessageTask, | |||
| rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False, | |||
| retriever_from: str = 'dev') -> list[BaseTool]: | |||
| def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]: | |||
| """ | |||
| Convert app agent tool configs to tools | |||
| :param agent_model_instance: | |||
| :param rest_tokens: | |||
| :param tool_configs: app agent tool configs | |||
| :param conversation_message_task: | |||
| :param callbacks: | |||
| :param return_resource: | |||
| :param retriever_from: | |||
| :return: | |||
| """ | |||
| tools = [] | |||
| @@ -195,15 +188,15 @@ class OrchestratorRuleParser: | |||
| tool = None | |||
| if tool_type == "dataset": | |||
| tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from) | |||
| tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs) | |||
| elif tool_type == "web_reader": | |||
| tool = self.to_web_reader_tool(agent_model_instance) | |||
| tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs) | |||
| elif tool_type == "google_search": | |||
| tool = self.to_google_search_tool() | |||
| tool = self.to_google_search_tool(tool_config=tool_val, **kwargs) | |||
| elif tool_type == "wikipedia": | |||
| tool = self.to_wikipedia_tool() | |||
| tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs) | |||
| elif tool_type == "current_datetime": | |||
| tool = self.to_current_datetime_tool() | |||
| tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs) | |||
| if tool: | |||
| if tool.callbacks is not None: | |||
| @@ -215,12 +208,15 @@ class OrchestratorRuleParser: | |||
| return tools | |||
| def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, | |||
| rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \ | |||
| dataset_configs: dict, rest_tokens: int, | |||
| return_resource: bool = False, retriever_from: str = 'dev', | |||
| **kwargs) \ | |||
| -> Optional[BaseTool]: | |||
| """ | |||
| A dataset tool is a tool that can be used to retrieve information from a dataset | |||
| :param rest_tokens: | |||
| :param tool_config: | |||
| :param dataset_configs: | |||
| :param conversation_message_task: | |||
| :param return_resource: | |||
| :param retriever_from: | |||
| @@ -238,10 +234,20 @@ class OrchestratorRuleParser: | |||
| if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: | |||
| return None | |||
| k = self._dynamic_calc_retrieve_k(dataset, rest_tokens) | |||
| top_k = dataset_configs.get("top_k", 2) | |||
| # dynamically adjust top_k when the remaining token number is not enough to support top_k | |||
| top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) | |||
| score_threshold = None | |||
| score_threshold_config = dataset_configs.get("score_threshold") | |||
| if score_threshold_config and score_threshold_config.get("enable"): | |||
| score_threshold = score_threshold_config.get("value") | |||
| tool = DatasetRetrieverTool.from_dataset( | |||
| dataset=dataset, | |||
| k=k, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| callbacks=[DatasetToolCallbackHandler(conversation_message_task)], | |||
| conversation_message_task=conversation_message_task, | |||
| return_resource=return_resource, | |||
| @@ -250,7 +256,7 @@ class OrchestratorRuleParser: | |||
| return tool | |||
| def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]: | |||
| def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]: | |||
| """ | |||
| A tool for reading web pages | |||
| @@ -278,7 +284,7 @@ class OrchestratorRuleParser: | |||
| return tool | |||
| def to_google_search_tool(self) -> Optional[BaseTool]: | |||
| def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]: | |||
| tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id) | |||
| func_kwargs = tool_provider.credentials_to_func_kwargs() | |||
| if not func_kwargs: | |||
| @@ -296,12 +302,12 @@ class OrchestratorRuleParser: | |||
| return tool | |||
| def to_current_datetime_tool(self) -> Optional[BaseTool]: | |||
| def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]: | |||
| tool = DatetimeTool() | |||
| return tool | |||
| def to_wikipedia_tool(self) -> Optional[BaseTool]: | |||
| def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]: | |||
| class WikipediaInput(BaseModel): | |||
| query: str = Field(..., description="search query.") | |||
| @@ -312,22 +318,18 @@ class OrchestratorRuleParser: | |||
| ) | |||
| @classmethod | |||
| def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: | |||
| DEFAULT_K = 2 | |||
| CONTEXT_TOKENS_PERCENT = 0.3 | |||
| MAX_K = 10 | |||
| def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int: | |||
| if rest_tokens == -1: | |||
| return DEFAULT_K | |||
| return top_k | |||
| processing_rule = dataset.latest_process_rule | |||
| if not processing_rule: | |||
| return DEFAULT_K | |||
| return top_k | |||
| if processing_rule.mode == "custom": | |||
| rules = processing_rule.rules_dict | |||
| if not rules: | |||
| return DEFAULT_K | |||
| return top_k | |||
| segmentation = rules["segmentation"] | |||
| segment_max_tokens = segmentation["max_tokens"] | |||
| @@ -335,14 +337,7 @@ class OrchestratorRuleParser: | |||
| segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'] | |||
| # when rest_tokens is less than default context tokens | |||
| if rest_tokens < segment_max_tokens * DEFAULT_K: | |||
| if rest_tokens < segment_max_tokens * top_k: | |||
| return rest_tokens // segment_max_tokens | |||
| context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT) | |||
| # when context_limit_tokens is less than default context tokens, use default_k | |||
| if context_limit_tokens <= segment_max_tokens * DEFAULT_K: | |||
| return DEFAULT_K | |||
| # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K | |||
| return min(context_limit_tokens // segment_max_tokens, MAX_K) | |||
| return min(top_k, 10) | |||
| @@ -0,0 +1,79 @@ | |||
| CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" | |||
| BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" | |||
| CHAT_APP_COMPLETION_PROMPT_CONFIG = { | |||
| "completion_prompt_config": { | |||
| "prompt": { | |||
| "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: " | |||
| }, | |||
| "conversation_histories_role": { | |||
| "user_prefix": "Human", | |||
| "assistant_prefix": "Assistant" | |||
| } | |||
| } | |||
| } | |||
| CHAT_APP_CHAT_PROMPT_CONFIG = { | |||
| "chat_prompt_config": { | |||
| "prompt": [{ | |||
| "role": "system", | |||
| "text": "{{#pre_prompt#}}" | |||
| }] | |||
| } | |||
| } | |||
| COMPLETION_APP_CHAT_PROMPT_CONFIG = { | |||
| "chat_prompt_config": { | |||
| "prompt": [{ | |||
| "role": "user", | |||
| "text": "{{#pre_prompt#}}" | |||
| }] | |||
| } | |||
| } | |||
| COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { | |||
| "completion_prompt_config": { | |||
| "prompt": { | |||
| "text": "{{#pre_prompt#}}" | |||
| } | |||
| } | |||
| } | |||
| BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { | |||
| "completion_prompt_config": { | |||
| "prompt": { | |||
| "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" | |||
| }, | |||
| "conversation_histories_role": { | |||
| "user_prefix": "用户", | |||
| "assistant_prefix": "助手" | |||
| } | |||
| } | |||
| } | |||
| BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { | |||
| "chat_prompt_config": { | |||
| "prompt": [{ | |||
| "role": "system", | |||
| "text": "{{#pre_prompt#}}" | |||
| }] | |||
| } | |||
| } | |||
| BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { | |||
| "chat_prompt_config": { | |||
| "prompt": [{ | |||
| "role": "user", | |||
| "text": "{{#pre_prompt#}}" | |||
| }] | |||
| } | |||
| } | |||
| BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { | |||
| "completion_prompt_config": { | |||
| "prompt": { | |||
| "text": "{{#pre_prompt#}}" | |||
| } | |||
| } | |||
| } | |||
| @@ -1,38 +1,24 @@ | |||
| import re | |||
| from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage | |||
| from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate | |||
| from langchain.schema import BaseMessage | |||
| from core.prompt.prompt_template import JinjaPromptTemplate | |||
| from core.prompt.prompt_template import PromptTemplateParser | |||
| class PromptBuilder: | |||
| @classmethod | |||
| def parse_prompt(cls, prompt: str, inputs: dict) -> str: | |||
| prompt_template = PromptTemplateParser(prompt) | |||
| prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} | |||
| prompt = prompt_template.format(prompt_inputs) | |||
| return prompt | |||
| @classmethod | |||
| def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: | |||
| prompt_template = JinjaPromptTemplate.from_template(prompt_content) | |||
| system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template) | |||
| prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs} | |||
| system_message = system_prompt_template.format(**prompt_inputs) | |||
| return system_message | |||
| return SystemMessage(content=cls.parse_prompt(prompt_content, inputs)) | |||
| @classmethod | |||
| def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: | |||
| prompt_template = JinjaPromptTemplate.from_template(prompt_content) | |||
| ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template) | |||
| prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs} | |||
| ai_message = ai_prompt_template.format(**prompt_inputs) | |||
| return ai_message | |||
| return AIMessage(content=cls.parse_prompt(prompt_content, inputs)) | |||
| @classmethod | |||
| def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: | |||
| prompt_template = JinjaPromptTemplate.from_template(prompt_content) | |||
| human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template) | |||
| human_message = human_prompt_template.format(**inputs) | |||
| return human_message | |||
| @classmethod | |||
| def process_template(cls, template: str): | |||
| processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template) | |||
| # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template) | |||
| # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template) | |||
| return processed_template | |||
| return HumanMessage(content=cls.parse_prompt(prompt_content, inputs)) | |||
| @@ -1,79 +1,39 @@ | |||
| import re | |||
| from typing import Any | |||
| from jinja2 import Environment, meta | |||
| from langchain import PromptTemplate | |||
| from langchain.formatting import StrictFormatter | |||
| REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}") | |||
| class JinjaPromptTemplate(PromptTemplate): | |||
| template_format: str = "jinja2" | |||
| """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" | |||
| class PromptTemplateParser: | |||
| """ | |||
| Rules: | |||
| @classmethod | |||
| def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: | |||
| """Load a prompt template from a template.""" | |||
| env = Environment() | |||
| template = template.replace("{{}}", "{}") | |||
| ast = env.parse(template) | |||
| input_variables = meta.find_undeclared_variables(ast) | |||
| if "partial_variables" in kwargs: | |||
| partial_variables = kwargs["partial_variables"] | |||
| input_variables = { | |||
| var for var in input_variables if var not in partial_variables | |||
| } | |||
| return cls( | |||
| input_variables=list(sorted(input_variables)), template=template, **kwargs | |||
| ) | |||
| class OutLinePromptTemplate(PromptTemplate): | |||
| @classmethod | |||
| def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: | |||
| """Load a prompt template from a template.""" | |||
| input_variables = { | |||
| v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None | |||
| } | |||
| return cls( | |||
| input_variables=list(sorted(input_variables)), template=template, **kwargs | |||
| ) | |||
| def format(self, **kwargs: Any) -> str: | |||
| """Format the prompt with the inputs. | |||
| 1. Template variables must be enclosed in `{{}}`. | |||
| 2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters, | |||
| and can only start with letters and underscores. | |||
| 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. | |||
| 4. In addition to the above, 3 types of special template variable Keys are accepted: | |||
| `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed. | |||
| """ | |||
| Args: | |||
| kwargs: Any arguments to be passed to the prompt template. | |||
| def __init__(self, template: str): | |||
| self.template = template | |||
| self.variable_keys = self.extract() | |||
| Returns: | |||
| A formatted string. | |||
| def extract(self) -> list: | |||
| # Regular expression to match the template rules | |||
| return re.findall(REGEX, self.template) | |||
| Example: | |||
| def format(self, inputs: dict, remove_template_variables: bool = True) -> str: | |||
| def replacer(match): | |||
| key = match.group(1) | |||
| value = inputs.get(key, match.group(0)) # return original matched string if key not found | |||
| .. code-block:: python | |||
| if remove_template_variables: | |||
| return PromptTemplateParser.remove_template_variables(value) | |||
| return value | |||
| prompt.format(variable1="foo") | |||
| """ | |||
| kwargs = self._merge_partial_and_user_variables(**kwargs) | |||
| return OneLineFormatter().format(self.template, **kwargs) | |||
| return re.sub(REGEX, replacer, self.template) | |||
| class OneLineFormatter(StrictFormatter): | |||
| def parse(self, format_string): | |||
| last_end = 0 | |||
| results = [] | |||
| for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string): | |||
| field_name = match.group(1) | |||
| start, end = match.span() | |||
| literal_text = format_string[last_end:start] | |||
| last_end = end | |||
| results.append((literal_text, field_name, '', None)) | |||
| remaining_literal_text = format_string[last_end:] | |||
| if remaining_literal_text: | |||
| results.append((remaining_literal_text, None, None, None)) | |||
| return results | |||
| @classmethod | |||
| def remove_template_variables(cls, text: str): | |||
| return re.sub(REGEX, r'{\1}', text) | |||
| @@ -61,36 +61,6 @@ User Input: yo, 你今天咋样? | |||
| User Input: | |||
| """ | |||
| CONVERSATION_SUMMARY_PROMPT = ( | |||
| "Please generate a short summary of the following conversation.\n" | |||
| "If the following conversation communicating in English, you should only return an English summary.\n" | |||
| "If the following conversation communicating in Chinese, you should only return a Chinese summary.\n" | |||
| "[Conversation Start]\n" | |||
| "{context}\n" | |||
| "[Conversation End]\n\n" | |||
| "summary:" | |||
| ) | |||
| INTRODUCTION_GENERATE_PROMPT = ( | |||
| "I am designing a product for users to interact with an AI through dialogue. " | |||
| "The Prompt given to the AI before the conversation is:\n\n" | |||
| "```\n{prompt}\n```\n\n" | |||
| "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. " | |||
| "Do not reveal the developer's motivation or deep logic behind the Prompt, " | |||
| "but focus on building a relationship with the user:\n" | |||
| ) | |||
| MORE_LIKE_THIS_GENERATE_PROMPT = ( | |||
| "-----\n" | |||
| "{original_completion}\n" | |||
| "-----\n\n" | |||
| "Please use the above content as a sample for generating the result, " | |||
| "and include key information points related to the original sample in the result. " | |||
| "Try to rephrase this information in different ways and predict according to the rules below.\n\n" | |||
| "-----\n" | |||
| "{prompt}\n" | |||
| ) | |||
| SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( | |||
| "Please help me predict the three most likely questions that human would ask, " | |||
| "and keeping each question under 20 characters.\n" | |||
| @@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR. | |||
| ``` | |||
| << MY INTENDED AUDIENCES >> | |||
| {audiences} | |||
| {{audiences}} | |||
| << HOPING TO SOLVE >> | |||
| {hoping_to_solve} | |||
| {{hoping_to_solve}} | |||
| << OUTPUT >> | |||
| """ | |||
| @@ -1,5 +1,5 @@ | |||
| import json | |||
| from typing import Type | |||
| from typing import Type, Optional | |||
| from flask import current_app | |||
| from langchain.tools import BaseTool | |||
| @@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool): | |||
| tenant_id: str | |||
| dataset_id: str | |||
| k: int = 3 | |||
| top_k: int = 2 | |||
| score_threshold: Optional[float] = None | |||
| conversation_message_task: ConversationMessageTask | |||
| return_resource: bool | |||
| retriever_from: str | |||
| @@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool): | |||
| ) | |||
| ) | |||
| documents = kw_table_index.search(query, search_kwargs={'k': self.k}) | |||
| documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) | |||
| return str("\n".join([document.page_content for document in documents])) | |||
| else: | |||
| @@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool): | |||
| return '' | |||
| except ProviderTokenNotInitError: | |||
| return '' | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| embeddings = CacheEmbedding(embedding_model) | |||
| vector_index = VectorIndex( | |||
| dataset=dataset, | |||
| config=current_app.config, | |||
| embeddings=embeddings | |||
| ) | |||
| if self.k > 0: | |||
| if self.top_k > 0: | |||
| documents = vector_index.search( | |||
| query, | |||
| search_type='similarity_score_threshold', | |||
| search_kwargs={ | |||
| 'k': self.k, | |||
| 'k': self.top_k, | |||
| 'score_threshold': self.score_threshold, | |||
| 'filter': { | |||
| 'group_id': [dataset.id] | |||
| } | |||
| @@ -4,5 +4,4 @@ from .clean_when_document_deleted import handle | |||
| from .clean_when_dataset_deleted import handle | |||
| from .update_app_dataset_join_when_app_model_config_updated import handle | |||
| from .generate_conversation_name_when_first_message_created import handle | |||
| from .generate_conversation_summary_when_few_message_created import handle | |||
| from .create_document_index import handle | |||
| @@ -1,14 +0,0 @@ | |||
| from events.message_event import message_was_created | |||
| from tasks.generate_conversation_summary_task import generate_conversation_summary_task | |||
| @message_was_created.connect | |||
| def handle(sender, **kwargs): | |||
| message = sender | |||
| conversation = kwargs.get('conversation') | |||
| is_first_message = kwargs.get('is_first_message') | |||
| if not is_first_message and conversation.mode == 'chat' and not conversation.summary: | |||
| history_message_count = conversation.message_count | |||
| if history_message_count >= 5: | |||
| generate_conversation_summary_task.delay(conversation.id) | |||
| @@ -28,6 +28,10 @@ model_config_fields = { | |||
| 'dataset_query_variable': fields.String, | |||
| 'pre_prompt': fields.String, | |||
| 'agent_mode': fields.Raw(attribute='agent_mode_dict'), | |||
| 'prompt_type': fields.String, | |||
| 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), | |||
| 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), | |||
| 'dataset_configs': fields.Raw(attribute='dataset_configs_dict') | |||
| } | |||
| app_detail_fields = { | |||
| @@ -123,6 +123,7 @@ conversation_with_summary_fields = { | |||
| 'from_end_user_id': fields.String, | |||
| 'from_end_user_session_id': fields.String, | |||
| 'from_account_id': fields.String, | |||
| 'name': fields.String, | |||
| 'summary': fields.String(attribute='summary_or_query'), | |||
| 'read_at': TimestampField, | |||
| 'created_at': TimestampField, | |||
| @@ -0,0 +1,37 @@ | |||
| """add advanced prompt templates | |||
| Revision ID: b3a09c049e8e | |||
| Revises: 2e9819ca5b28 | |||
| Create Date: 2023-10-10 15:23:23.395420 | |||
| """ | |||
| from alembic import op | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'b3a09c049e8e' | |||
| down_revision = '2e9819ca5b28' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) | |||
| batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True)) | |||
| batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True)) | |||
| batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True)) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('app_model_configs', schema=None) as batch_op: | |||
| batch_op.drop_column('dataset_configs') | |||
| batch_op.drop_column('completion_prompt_config') | |||
| batch_op.drop_column('chat_prompt_config') | |||
| batch_op.drop_column('prompt_type') | |||
| # ### end Alembic commands ### | |||
| @@ -93,6 +93,10 @@ class AppModelConfig(db.Model): | |||
| agent_mode = db.Column(db.Text) | |||
| sensitive_word_avoidance = db.Column(db.Text) | |||
| retriever_resource = db.Column(db.Text) | |||
| prompt_type = db.Column(db.String(255), nullable=False, default='simple') | |||
| chat_prompt_config = db.Column(db.Text) | |||
| completion_prompt_config = db.Column(db.Text) | |||
| dataset_configs = db.Column(db.Text) | |||
| @property | |||
| def app(self): | |||
| @@ -139,6 +143,18 @@ class AppModelConfig(db.Model): | |||
| def agent_mode_dict(self) -> dict: | |||
| return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []} | |||
| @property | |||
| def chat_prompt_config_dict(self) -> dict: | |||
| return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} | |||
| @property | |||
| def completion_prompt_config_dict(self) -> dict: | |||
| return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} | |||
| @property | |||
| def dataset_configs_dict(self) -> dict: | |||
| return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} | |||
| def to_dict(self) -> dict: | |||
| return { | |||
| "provider": "", | |||
| @@ -155,7 +171,11 @@ class AppModelConfig(db.Model): | |||
| "user_input_form": self.user_input_form_list, | |||
| "dataset_query_variable": self.dataset_query_variable, | |||
| "pre_prompt": self.pre_prompt, | |||
| "agent_mode": self.agent_mode_dict | |||
| "agent_mode": self.agent_mode_dict, | |||
| "prompt_type": self.prompt_type, | |||
| "chat_prompt_config": self.chat_prompt_config_dict, | |||
| "completion_prompt_config": self.completion_prompt_config_dict, | |||
| "dataset_configs": self.dataset_configs_dict | |||
| } | |||
| def from_model_config_dict(self, model_config: dict): | |||
| @@ -177,6 +197,13 @@ class AppModelConfig(db.Model): | |||
| self.agent_mode = json.dumps(model_config['agent_mode']) | |||
| self.retriever_resource = json.dumps(model_config['retriever_resource']) \ | |||
| if model_config.get('retriever_resource') else None | |||
| self.prompt_type = model_config.get('prompt_type', 'simple') | |||
| self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \ | |||
| if model_config.get('chat_prompt_config') else None | |||
| self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \ | |||
| if model_config.get('completion_prompt_config') else None | |||
| self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ | |||
| if model_config.get('dataset_configs') else None | |||
| return self | |||
| def copy(self): | |||
| @@ -197,7 +224,11 @@ class AppModelConfig(db.Model): | |||
| dataset_query_variable=self.dataset_query_variable, | |||
| pre_prompt=self.pre_prompt, | |||
| agent_mode=self.agent_mode, | |||
| retriever_resource=self.retriever_resource | |||
| retriever_resource=self.retriever_resource, | |||
| prompt_type=self.prompt_type, | |||
| chat_prompt_config=self.chat_prompt_config, | |||
| completion_prompt_config=self.completion_prompt_config, | |||
| dataset_configs=self.dataset_configs | |||
| ) | |||
| return new_app_model_config | |||
| @@ -0,0 +1,56 @@ | |||
| import copy | |||
| from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \ | |||
| BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT | |||
| class AdvancedPromptTemplateService: | |||
| def get_prompt(self, args: dict) -> dict: | |||
| app_mode = args['app_mode'] | |||
| model_mode = args['model_mode'] | |||
| model_name = args['model_name'] | |||
| has_context = args['has_context'] | |||
| if 'baichuan' in model_name: | |||
| return self.get_baichuan_prompt(app_mode, model_mode, has_context) | |||
| else: | |||
| return self.get_common_prompt(app_mode, model_mode, has_context) | |||
| def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict: | |||
| if app_mode == 'chat': | |||
| if model_mode == 'completion': | |||
| return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT) | |||
| elif model_mode == 'chat': | |||
| return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT) | |||
| elif app_mode == 'completion': | |||
| if model_mode == 'completion': | |||
| return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT) | |||
| elif model_mode == 'chat': | |||
| return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT) | |||
| def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict: | |||
| if has_context == 'true': | |||
| prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] | |||
| return prompt_template | |||
| def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict: | |||
| if has_context == 'true': | |||
| prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] | |||
| return prompt_template | |||
| def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict: | |||
| if app_mode == 'chat': | |||
| if model_mode == 'completion': | |||
| return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) | |||
| elif model_mode == 'chat': | |||
| return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) | |||
| elif app_mode == 'completion': | |||
| if model_mode == 'completion': | |||
| return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) | |||
| elif model_mode == 'chat': | |||
| return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) | |||
| @@ -3,7 +3,7 @@ import uuid | |||
| from core.agent.agent_executor import PlanningStrategy | |||
| from core.model_providers.model_provider_factory import ModelProviderFactory | |||
| from core.model_providers.models.entity.model_params import ModelType | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelMode | |||
| from models.account import Account | |||
| from services.dataset_service import DatasetService | |||
| @@ -34,40 +34,28 @@ class AppModelConfigService: | |||
| # max_tokens | |||
| if 'max_tokens' not in cp: | |||
| cp["max_tokens"] = 512 | |||
| # | |||
| # if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \ | |||
| # llm_constant.max_context_token_length[model_name]: | |||
| # raise ValueError( | |||
| # "max_tokens must be an integer greater than 0 " | |||
| # "and not exceeding the maximum value of the corresponding model") | |||
| # | |||
| # temperature | |||
| if 'temperature' not in cp: | |||
| cp["temperature"] = 1 | |||
| # | |||
| # if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2: | |||
| # raise ValueError("temperature must be a float between 0 and 2") | |||
| # | |||
| # top_p | |||
| if 'top_p' not in cp: | |||
| cp["top_p"] = 1 | |||
| # if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2: | |||
| # raise ValueError("top_p must be a float between 0 and 2") | |||
| # | |||
| # presence_penalty | |||
| if 'presence_penalty' not in cp: | |||
| cp["presence_penalty"] = 0 | |||
| # if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2: | |||
| # raise ValueError("presence_penalty must be a float between -2 and 2") | |||
| # | |||
| # presence_penalty | |||
| if 'frequency_penalty' not in cp: | |||
| cp["frequency_penalty"] = 0 | |||
| # if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2: | |||
| # raise ValueError("frequency_penalty must be a float between -2 and 2") | |||
| # stop | |||
| if 'stop' not in cp: | |||
| cp["stop"] = [] | |||
| elif not isinstance(cp["stop"], list): | |||
| raise ValueError("stop in model.completion_params must be of list type") | |||
| # Filter out extra parameters | |||
| filtered_cp = { | |||
| @@ -75,7 +63,8 @@ class AppModelConfigService: | |||
| "temperature": cp["temperature"], | |||
| "top_p": cp["top_p"], | |||
| "presence_penalty": cp["presence_penalty"], | |||
| "frequency_penalty": cp["frequency_penalty"] | |||
| "frequency_penalty": cp["frequency_penalty"], | |||
| "stop": cp["stop"] | |||
| } | |||
| return filtered_cp | |||
| @@ -211,6 +200,10 @@ class AppModelConfigService: | |||
| model_ids = [m['id'] for m in model_list] | |||
| if config["model"]["name"] not in model_ids: | |||
| raise ValueError("model.name must be in the specified model list") | |||
| # model.mode | |||
| if 'mode' not in config['model'] or not config['model']["mode"]: | |||
| config['model']["mode"] = "" | |||
| # model.completion_params | |||
| if 'completion_params' not in config["model"]: | |||
| @@ -339,6 +332,9 @@ class AppModelConfigService: | |||
| # dataset_query_variable | |||
| AppModelConfigService.is_dataset_query_variable_valid(config, mode) | |||
| # advanced prompt validation | |||
| AppModelConfigService.is_advanced_prompt_valid(config, mode) | |||
| # Filter out extra parameters | |||
| filtered_config = { | |||
| "opening_statement": config["opening_statement"], | |||
| @@ -351,12 +347,17 @@ class AppModelConfigService: | |||
| "model": { | |||
| "provider": config["model"]["provider"], | |||
| "name": config["model"]["name"], | |||
| "mode": config['model']["mode"], | |||
| "completion_params": config["model"]["completion_params"] | |||
| }, | |||
| "user_input_form": config["user_input_form"], | |||
| "dataset_query_variable": config.get('dataset_query_variable'), | |||
| "pre_prompt": config["pre_prompt"], | |||
| "agent_mode": config["agent_mode"] | |||
| "agent_mode": config["agent_mode"], | |||
| "prompt_type": config["prompt_type"], | |||
| "chat_prompt_config": config["chat_prompt_config"], | |||
| "completion_prompt_config": config["completion_prompt_config"], | |||
| "dataset_configs": config["dataset_configs"] | |||
| } | |||
| return filtered_config | |||
| @@ -375,4 +376,51 @@ class AppModelConfigService: | |||
| if dataset_exists and not dataset_query_variable: | |||
| raise ValueError("Dataset query variable is required when dataset is exist") | |||
| @staticmethod | |||
| def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: | |||
| # prompt_type | |||
| if 'prompt_type' not in config or not config["prompt_type"]: | |||
| config["prompt_type"] = "simple" | |||
| if config['prompt_type'] not in ['simple', 'advanced']: | |||
| raise ValueError("prompt_type must be in ['simple', 'advanced']") | |||
| # chat_prompt_config | |||
| if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: | |||
| config["chat_prompt_config"] = {} | |||
| if not isinstance(config["chat_prompt_config"], dict): | |||
| raise ValueError("chat_prompt_config must be of object type") | |||
| # completion_prompt_config | |||
| if 'completion_prompt_config' not in config or not config["completion_prompt_config"]: | |||
| config["completion_prompt_config"] = {} | |||
| if not isinstance(config["completion_prompt_config"], dict): | |||
| raise ValueError("completion_prompt_config must be of object type") | |||
| # dataset_configs | |||
| if 'dataset_configs' not in config or not config["dataset_configs"]: | |||
| config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} | |||
| if not isinstance(config["dataset_configs"], dict): | |||
| raise ValueError("dataset_configs must be of object type") | |||
| if config['prompt_type'] == 'advanced': | |||
| if not config['chat_prompt_config'] and not config['completion_prompt_config']: | |||
| raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") | |||
| if config['model']["mode"] not in ['chat', 'completion']: | |||
| raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") | |||
| if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value: | |||
| user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] | |||
| assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] | |||
| if not user_prefix: | |||
| config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' | |||
| if not assistant_prefix: | |||
| config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' | |||
| @@ -244,7 +244,8 @@ class CompletionService: | |||
| @classmethod | |||
| def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], | |||
| message_id: str, streaming: bool = True) -> Union[dict | Generator]: | |||
| message_id: str, streaming: bool = True, | |||
| retriever_from: str = 'dev') -> Union[dict | Generator]: | |||
| if not user: | |||
| raise ValueError('user cannot be None') | |||
| @@ -266,14 +267,11 @@ class CompletionService: | |||
| raise MoreLikeThisDisabledError() | |||
| app_model_config = message.app_model_config | |||
| if message.override_model_configs: | |||
| override_model_configs = json.loads(message.override_model_configs) | |||
| pre_prompt = override_model_configs.get("pre_prompt", '') | |||
| elif app_model_config: | |||
| pre_prompt = app_model_config.pre_prompt | |||
| else: | |||
| raise AppModelConfigBrokenError() | |||
| model_dict = app_model_config.model_dict | |||
| completion_params = model_dict.get('completion_params') | |||
| completion_params['temperature'] = 0.9 | |||
| model_dict['completion_params'] = completion_params | |||
| app_model_config.model = json.dumps(model_dict) | |||
| generate_task_id = str(uuid.uuid4()) | |||
| @@ -282,58 +280,28 @@ class CompletionService: | |||
| user = cls.get_real_user_instead_of_proxy_obj(user) | |||
| generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={ | |||
| generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={ | |||
| 'flask_app': current_app._get_current_object(), | |||
| 'generate_task_id': generate_task_id, | |||
| 'detached_app_model': app_model, | |||
| 'app_model_config': app_model_config, | |||
| 'detached_message': message, | |||
| 'pre_prompt': pre_prompt, | |||
| 'query': message.query, | |||
| 'inputs': message.inputs, | |||
| 'detached_user': user, | |||
| 'streaming': streaming | |||
| 'detached_conversation': None, | |||
| 'streaming': streaming, | |||
| 'is_model_config_override': True, | |||
| 'retriever_from': retriever_from | |||
| }) | |||
| generate_worker_thread.start() | |||
| cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id) | |||
| # wait for 10 minutes to close the thread | |||
| cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, | |||
| generate_task_id) | |||
| return cls.compact_response(pubsub, streaming) | |||
| @classmethod | |||
| def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, | |||
| app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str, | |||
| detached_user: Union[Account, EndUser], streaming: bool): | |||
| with flask_app.app_context(): | |||
| # fixed the state of the model object when it detached from the original session | |||
| user = db.session.merge(detached_user) | |||
| app_model = db.session.merge(detached_app_model) | |||
| message = db.session.merge(detached_message) | |||
| try: | |||
| # run | |||
| Completion.generate_more_like_this( | |||
| task_id=generate_task_id, | |||
| app=app_model, | |||
| user=user, | |||
| message=message, | |||
| pre_prompt=pre_prompt, | |||
| app_model_config=app_model_config, | |||
| streaming=streaming | |||
| ) | |||
| except ConversationTaskStoppedException: | |||
| pass | |||
| except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, | |||
| LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, | |||
| ModelCurrentlyNotSupportError) as e: | |||
| PubHandler.pub_error(user, generate_task_id, e) | |||
| except LLMAuthorizationError: | |||
| PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided')) | |||
| except Exception as e: | |||
| logging.exception("Unknown Error in completion") | |||
| PubHandler.pub_error(user, generate_task_id, e) | |||
| finally: | |||
| db.session.commit() | |||
| @classmethod | |||
| def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): | |||
| if user_inputs is None: | |||
| @@ -482,6 +482,9 @@ class ProviderService: | |||
| 'features': [] | |||
| } | |||
| if 'mode' in model: | |||
| valid_model_dict['model_mode'] = model['mode'] | |||
| if 'features' in model: | |||
| valid_model_dict['features'] = model['features'] | |||
| @@ -1,55 +0,0 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from werkzeug.exceptions import NotFound | |||
| from core.generator.llm_generator import LLMGenerator | |||
| from core.model_providers.error import LLMError, ProviderTokenNotInitError | |||
| from extensions.ext_database import db | |||
| from models.model import Conversation, Message | |||
| @shared_task(queue='generation') | |||
| def generate_conversation_summary_task(conversation_id: str): | |||
| """ | |||
| Async Generate conversation summary | |||
| :param conversation_id: | |||
| Usage: generate_conversation_summary_task.delay(conversation_id) | |||
| """ | |||
| logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green')) | |||
| start_at = time.perf_counter() | |||
| conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() | |||
| if not conversation: | |||
| raise NotFound('Conversation not found') | |||
| try: | |||
| # get conversation messages count | |||
| history_message_count = conversation.message_count | |||
| if history_message_count >= 5 and not conversation.summary: | |||
| app_model = conversation.app | |||
| if not app_model: | |||
| return | |||
| history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ | |||
| .order_by(Message.created_at.asc()).all() | |||
| conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages) | |||
| db.session.add(conversation) | |||
| db.session.commit() | |||
| except (LLMError, ProviderTokenNotInitError): | |||
| conversation.summary = '[No Summary]' | |||
| db.session.commit() | |||
| pass | |||
| except Exception as e: | |||
| conversation.summary = '[No Summary]' | |||
| db.session.commit() | |||
| logging.exception(e) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), | |||
| fg='green')) | |||
| @@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('claude-2') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 6 | |||
| @@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker): | |||
| openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker) | |||
| rst = openai_model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 22 | |||
| @@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('baichuan2-53b') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst > 0 | |||
| @@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker): | |||
| model = get_mock_model('baichuan2-53b') | |||
| messages = [ | |||
| PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| ] | |||
| rst = model.run( | |||
| messages, | |||
| @@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker): | |||
| model = get_mock_model('baichuan2-53b', streaming=True) | |||
| messages = [ | |||
| PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| ] | |||
| rst = model.run( | |||
| messages | |||
| @@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock | |||
| mocker | |||
| ) | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke | |||
| mocker | |||
| ) | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('abab5.5-chat') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt): | |||
| openai_model = get_mock_openai_model('gpt-3.5-turbo') | |||
| rst = openai_model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 22 | |||
| @@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt, mocker): | |||
| model = get_mock_model('facebook/opt-125m', mocker) | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt, mocker): | |||
| model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker) | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 7 | |||
| @@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('spark') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 6 | |||
| @@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('qwen-turbo') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('ernie-bot') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): | |||
| def test_get_num_tokens(mock_decrypt, mocker): | |||
| model = get_mock_model('llama-2-chat', mocker) | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst == 5 | |||
| @@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt): | |||
| model = get_mock_model('chatglm_lite') | |||
| rst = model.get_num_tokens([ | |||
| PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), | |||
| PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') | |||
| PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') | |||
| ]) | |||
| assert rst > 0 | |||
| @@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker): | |||
| model = get_mock_model('chatglm_lite') | |||
| messages = [ | |||
| PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| ] | |||
| rst = model.run( | |||
| messages, | |||
| @@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker): | |||
| model = get_mock_model('chatglm_lite', streaming=True) | |||
| messages = [ | |||
| PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') | |||
| ] | |||
| rst = model.run( | |||
| messages | |||
| @@ -1,7 +1,7 @@ | |||
| from typing import Type | |||
| from core.model_providers.models.base import BaseProviderModel | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules | |||
| from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode | |||
| from core.model_providers.models.llm.openai_model import OpenAIModel | |||
| from core.model_providers.providers.base import BaseModelProvider | |||
| @@ -12,7 +12,10 @@ class FakeModelProvider(BaseModelProvider): | |||
| return 'fake' | |||
| def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: | |||
| return [{'id': 'test_model', 'name': 'Test Model'}] | |||
| return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}] | |||
| def _get_text_generation_model_mode(self, model_name) -> str: | |||
| return ModelMode.COMPLETION.value | |||
| def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: | |||
| return OpenAIModel | |||
| @@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker): | |||
| provider = FakeModelProvider(provider=Provider()) | |||
| result = provider.get_supported_model_list(ModelType.TEXT_GENERATION) | |||
| assert result == [{'id': 'test_model', 'name': 'test_model'}] | |||
| assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}] | |||
| def test_check_quota_over_limit(mocker): | |||