### What problem does this PR solve?

Fix: The max tokens defined by the tenant are not used (#4297) (#2817)

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
```diff
 from api.db.services.common_service import CommonService
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
+from api.db.services.llm_service import TenantLLMService, LLMBundle
 from api import settings
 from graphrag.utils import get_tags_from_cache, set_tags_to_cache
 from rag.app.resume import forbidden_select_fields4resume
```
```diff
 chat_start_ts = timer()

-# Get llm model name and model provider name
-llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
-# Get llm model instance by model and provider name
-llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
-if not llm:
-    # Model name is provided by tenant, but not system built-in
-    llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
-        TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
-    if not llm:
-        raise LookupError("LLM(%s) not found" % dialog.llm_id)
-    max_tokens = 8192
-else:
-    max_tokens = llm[0].max_tokens
+if llm_id2llm_type(dialog.llm_id) == "image2text":
+    llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
+else:
+    llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
+max_tokens = llm_model_config.get("max_tokens", 8192)

 check_llm_ts = timer()
```
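The behavioral fix is in this hunk: `max_tokens` now comes from the tenant's model configuration via the new `TenantLLMService.get_model_config` helper (defined in the next hunk), and the 8192 default applies only when the tenant has not set a value. A minimal sketch of the resulting lookup, where `resolve_max_tokens` is a hypothetical wrapper and the import paths are assumed from the module's own imports:

```python
from api.db import LLMType
from api.db.services.llm_service import TenantLLMService

def resolve_max_tokens(tenant_id: str, llm_id: str, default: int = 8192) -> int:
    """Hypothetical wrapper: look up the tenant-configured token limit for a chat model."""
    # get_model_config raises LookupError if the tenant or model cannot be resolved.
    config = TenantLLMService.get_model_config(tenant_id, LLMType.CHAT, llm_id)
    # A tenant-defined max_tokens wins; the default applies only when unset.
    return config.get("max_tokens", default)
```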
```diff
 @classmethod
 @DB.connection_context()
-def model_instance(cls, tenant_id, llm_type,
-                   llm_name=None, lang="Chinese"):
+def get_model_config(cls, tenant_id, llm_type, llm_name=None):
     e, tenant = TenantService.get_by_id(tenant_id)
     if not e:
         raise LookupError("Tenant not found")
     ...
     if not mdlnm:
         raise LookupError(f"Type of {llm_type} model is not set.")
     ...
         raise LookupError("Model({}) not authorized".format(mdlnm))
+    return model_config
+
+@classmethod
+@DB.connection_context()
+def model_instance(cls, tenant_id, llm_type,
+                   llm_name=None, lang="Chinese"):
+    model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
     if llm_type == LLMType.EMBEDDING.value:
         if model_config["llm_factory"] not in EmbeddingModel:
             return
```
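Extracting `get_model_config` out of `model_instance` gives every caller the same validation: judging by the `LookupError`s visible above, the helper raises for an unknown tenant, an unset model type, and an unauthorized model, and otherwise returns a config dict. A hedged sketch of how a caller might consume that contract (`safe_max_tokens` is hypothetical, not part of this PR):

```python
import logging

from api.db import LLMType
from api.db.services.llm_service import TenantLLMService

def safe_max_tokens(tenant_id: str, llm_name: str) -> int:
    """Hypothetical helper: degrade gracefully when the model config cannot be resolved."""
    try:
        config = TenantLLMService.get_model_config(tenant_id, LLMType.CHAT, llm_name)
    except LookupError as exc:
        # Raised for: tenant not found, model type not set, or model not authorized.
        logging.warning("falling back to default max_tokens: %s", exc)
        return 8192
    return config.get("max_tokens", 8192)
```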
```diff
         tenant_id, llm_type, llm_name, lang=lang)
     assert self.mdl, "Can't find model for {}/{}/{}".format(
         tenant_id, llm_type, llm_name)
-    self.max_length = 8192
-    for lm in LLMService.query(llm_name=llm_name):
-        self.max_length = lm.max_tokens
-        break
+    model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
+    self.max_length = model_config.get("max_tokens", 8192)

 def encode(self, texts: list):
     embeddings, used_tokens = self.mdl.encode(texts)
```
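`LLMBundle` gets the same treatment: instead of defaulting to 8192 and overriding it from the built-in `LLMService` table, `max_length` is now read from the tenant-level config. Across all three hunks the fallback semantics reduce to a single `dict.get` pattern, checked here with plain dicts standing in for the real config (a self-contained illustration, not the project's test suite):

```python
def test_tenant_max_tokens_is_used():
    # Tenant configured an explicit limit: it must win over the default.
    config = {"llm_factory": "OpenAI", "llm_name": "gpt-4o-mini", "max_tokens": 32768}
    assert config.get("max_tokens", 8192) == 32768

def test_missing_max_tokens_falls_back():
    # No tenant value set: the 8192 default applies.
    config = {"llm_factory": "OpenAI", "llm_name": "gpt-4o-mini"}
    assert config.get("max_tokens", 8192) == 8192
```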