### What problem does this PR solve?

Fix: The max tokens defined by the tenant are not used (#4297) (#2817)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
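To make the bug concrete before the diff: the old code resolved `max_tokens` from the built-in `LLM` table, so a per-tenant override never took effect for system-provided models. A minimal sketch of that behavior, with made-up dicts standing in for RAGFlow's tables (not RAGFlow's actual API):

```python
# Minimal sketch of the bug (names simplified; not RAGFlow's actual API).
SYSTEM_LLM_REGISTRY = {"gpt-4o": {"max_tokens": 128000}}          # built-in default
TENANT_MODEL_CONFIG = {"llm_name": "gpt-4o", "max_tokens": 4096}  # tenant override

def max_tokens_before(llm_name: str) -> int:
    # Old behavior: the tenant's value is never consulted for built-in models.
    return SYSTEM_LLM_REGISTRY[llm_name]["max_tokens"]

def max_tokens_after(model_config: dict) -> int:
    # New behavior: the tenant config wins; fall back to 8192 if unset,
    # matching the default used throughout this PR.
    return model_config.get("max_tokens", 8192)

assert max_tokens_before("gpt-4o") == 128000           # the bug: 4096 ignored
assert max_tokens_after(TENANT_MODEL_CONFIG) == 4096   # the fix: 4096 used
```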
@@ -29,7 +29,7 @@ from api.db.db_models import Dialog, DB
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
+from api.db.services.llm_service import TenantLLMService, LLMBundle
 from api import settings
 from graphrag.utils import get_tags_from_cache, set_tags_to_cache
 from rag.app.resume import forbidden_select_fields4resume
@@ -172,21 +172,12 @@ def chat(dialog, messages, stream=True, **kwargs):
     chat_start_ts = timer()
 
-    # Get llm model name and model provider name
-    llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
-    # Get llm model instance by model and provide name
-    llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
-    if not llm:
-        # Model name is provided by tenant, but not system built-in
-        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
-            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
-        if not llm:
-            raise LookupError("LLM(%s) not found" % dialog.llm_id)
-        max_tokens = 8192
+    if llm_id2llm_type(dialog.llm_id) == "image2text":
+        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
-        max_tokens = llm[0].max_tokens
+        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+
+    max_tokens = llm_model_config.get("max_tokens", 8192)
 
     check_llm_ts = timer()
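The `chat()` path now asks `TenantLLMService.get_model_config` for the tenant's config dict and reads `max_tokens` from it. A small sketch of the fallback semantics this relies on (plain `dict.get`, so the 8192 default only applies when the key is absent):

```python
# dict.get only falls back when the key is missing entirely.
assert {"max_tokens": 4096}.get("max_tokens", 8192) == 4096   # tenant value used
assert {}.get("max_tokens", 8192) == 8192                     # key absent -> default
assert {"max_tokens": None}.get("max_tokens", 8192) is None   # explicit None is kept
```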
@@ -86,8 +86,7 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
-    def model_instance(cls, tenant_id, llm_type,
-                       llm_name=None, lang="Chinese"):
+    def get_model_config(cls, tenant_id, llm_type, llm_name=None):
         e, tenant = TenantService.get_by_id(tenant_id)
         if not e:
             raise LookupError("Tenant not found")
@@ -124,7 +123,13 @@ class TenantLLMService(CommonService):
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
                 raise LookupError("Model({}) not authorized".format(mdlnm))
+        return model_config
+
+    @classmethod
+    @DB.connection_context()
+    def model_instance(cls, tenant_id, llm_type,
+                       llm_name=None, lang="Chinese"):
+        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
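With the config lookup extracted into `get_model_config`, callers can fetch the raw tenant config dict without building a model instance, and `model_instance` delegates to the same lookup. A hedged usage sketch (the tenant id and model name are placeholders, not real ids):

```python
# Placeholders: "tenant-123" and "gpt-4o" are illustrative only.
config = TenantLLMService.get_model_config("tenant-123", LLMType.CHAT, "gpt-4o")
print(config["llm_factory"], config.get("max_tokens", 8192))

# model_instance() now builds on the same config internally.
chat_mdl = TenantLLMService.model_instance("tenant-123", LLMType.CHAT, "gpt-4o")
```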
@@ -228,10 +233,8 @@ class LLMBundle(object):
             tenant_id, llm_type, llm_name, lang=lang)
         assert self.mdl, "Can't find model for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
-        self.max_length = 8192
-        for lm in LLMService.query(llm_name=llm_name):
-            self.max_length = lm.max_tokens
-            break
+        model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
+        self.max_length = model_config.get("max_tokens", 8192)
 
     def encode(self, texts: list):
         embeddings, used_tokens = self.mdl.encode(texts)
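`LLMBundle` likewise derives `max_length` from the tenant's model config instead of scanning the built-in `LLM` table. A usage sketch under the same placeholder ids, assuming the constructor mirrors the `model_instance` parameters shown above:

```python
# Placeholders again; the constructor signature is assumed from the
# model_instance(tenant_id, llm_type, llm_name, lang=lang) call above.
bundle = LLMBundle("tenant-123", LLMType.CHAT, llm_name="gpt-4o", lang="English")
print(bundle.max_length)  # tenant-configured max_tokens, or 8192 if unset
```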